Merge branch 'develop' of github.com:matrix-org/synapse into erikj/fed_reader

This commit is contained in:
Erik Johnston 2016-07-29 11:24:56 +01:00
commit 5aa024e501
54 changed files with 1256 additions and 351 deletions

View File

@ -1,3 +1,67 @@
Changes in synapse v0.17.0-rc1 (2016-07-28)
===========================================
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
This release contains significant security bug fixes regarding authenticating
events received over federation. Please upgrade.
Features:
* Add purge_media_cache admin API (PR #902)
* Add deactivate account admin API (PR #903)
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
* Add an admin option to shared secret registration (breaks backwards compat)
(PR #909)
* Add purge local room history API (PR #911, #923, #924)
* Add requestToken endpoints (PR #915)
* Add an /account/deactivate endpoint (PR #921)
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
* Add device_id support to /login (PR #929)
* Add device_id support to /v2/register flow. (PR #937, #942)
* Add GET /devices endpoint (PR #939, #944)
* Add GET /device/{deviceId} (PR #943)
* Add update and delete APIs for devices (PR #949)
Changes:
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
* Remove the legacy v0 content upload API. (PR #888)
* Use similar naming we use in email notifs for push (PR #894)
* Optionally include password hash in createUser endpoint (PR #905 by
KentShikama)
* Use a query that postgresql optimises better for get_events_around (PR #906)
* Fall back to 'username' if 'user' is not given for appservice registration.
(PR #927 by Half-Shot)
* Add metrics for psutil derived memory usage (PR #936)
* Record device_id in client_ips (PR #938)
* Send the correct host header when fetching keys (PR #941)
* Log the hostname the reCAPTCHA was completed on (PR #946)
* Make the device id on e2e key upload optional (PR #956)
* Add r0.2.0 to the "supported versions" list (PR #960)
* Don't include name of room for invites in push (PR #961)
Bug fixes:
* Fix substitution failure in mail template (PR #887)
* Put most recent 20 messages in email notif (PR #892)
* Ensure that the guest user is in the database when upgrading accounts
(PR #914)
* Fix various edge cases in auth handling (PR #919)
* Fix 500 ISE when sending alias event without a state_key (PR #925)
* Fix bug where we stored rejections in the state_group, persist all
rejections (PR #948)
* Fix lack of check of if the user is banned when handling 3pid invites
(PR #952)
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================

View File

@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
IDs:
1) Use the machine's own hostname as available on public DNS in the form of
its A or AAAA records. This is easier to set up initially, perhaps for
its A records. This is easier to set up initially, perhaps for
testing, but lacks the flexibility of SRV.
2) Set up a SRV record for your domain name. This requires you create a SRV

12
docs/admin_api/README.rst Normal file
View File

@ -0,0 +1,12 @@
Admin APIs
==========
This directory includes documentation for the various synapse specific admin
APIs available.
Only users that are server admins can use these APIs. A user can be marked as a
server admin by updating the database directly, e.g.:
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
Restarting may be required for the changes to register.

View File

@ -0,0 +1,15 @@
Purge History API
=================
The purge history API allows server admins to purge historic events from their
database, reclaiming disk space.
Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from.
The API is simply:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
including an ``access_token`` of a server admin.

View File

@ -0,0 +1,19 @@
Purge Remote Media API
======================
The purge remote media API allows server admins to purge old cached remote
media.
The API is::
POST /_matrix/client/r0/admin/purge_media_cache
{
"before_ts": <unix_timestamp_in_ms>
}
Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.
If the user re-requests purged remote media, synapse will re-request the media
from the originating server.

View File

@ -16,7 +16,5 @@ ignore =
[flake8]
max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
[pep8]
max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
ignore = W503

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.16.1-r1"
__version__ = "0.17.0-rc1"

View File

@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, UserID, get_domain_from_id
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64
import logging
import pymacaroons
import synapse.types
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@ -376,6 +376,10 @@ class Auth(object):
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
@ -566,8 +570,7 @@ class Auth(object):
Args:
request - An HTTP request with an access_token query parameter.
Returns:
defer.Deferred: resolves to a namedtuple including "user" (UserID)
"access_token_id" (int), "is_guest" (bool)
defer.Deferred: resolves to a ``synapse.types.Requester`` object
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@ -576,9 +579,7 @@ class Auth(object):
user_id = yield self._get_appservice_user_id(request.args)
if user_id:
request.authenticated_entity = user_id
defer.returnValue(
Requester(UserID.from_string(user_id), "", False)
)
defer.returnValue(synapse.types.create_requester(user_id))
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token, rights)
@ -612,7 +613,8 @@ class Auth(object):
request.authenticated_entity = user.to_string()
defer.returnValue(Requester(user, token_id, is_guest))
defer.returnValue(synapse.types.create_requester(
user, token_id, is_guest, device_id))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",

View File

@ -16,13 +16,11 @@
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, MissingRequirementError
) # NOQA
from synapse import python_dependencies # noqa: E402
try:
check_requirements()
except MissingRequirementError as e:
python_dependencies.check_requirements()
except python_dependencies.MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",

View File

@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self):
self.remote_key = defer.Deferred()
self.host = None
self._peer = None
def connectionMade(self):
self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host)
self._peer = self.transport.getPeer()
logger.debug("Connected to %s", self._peer)
self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel()
def on_timeout(self):
logger.debug("Timeout waiting for response from %s", self.host)
logger.debug(
"Timeout waiting for response from %s: %s",
self.host, self._peer,
)
self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
def protocol(self):
protocol = SynapseKeyClientProtocol()
protocol.path = self.path
protocol.host = self.host
return protocol

View File

@ -44,7 +44,21 @@ import logging
logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
VerifyKeyRequest = namedtuple("VerifyRequest", (
"server_name", "key_ids", "json_object", "deferred"
))
"""
A request for a verify key to verify a JSON object.
Attributes:
server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids to that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched
"""
class Keyring(object):
@ -74,39 +88,32 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
verify_requests = []
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
deferreds[group_id] = defer.fail(SynapseError(
deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
deferreds[group_id] = defer.Deferred()
deferred = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
verify_requests.append(verify_request)
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
def handle_key_deferred(verify_request):
server_name = verify_request.server_name
try:
_, _, key_id, verify_key = yield deferred
_, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
@ -128,7 +135,7 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
json_object = group_id_to_json[group.group_id]
json_object = verify_request.json_object
try:
verify_signed_json(json_object, server_name, verify_key)
@ -157,36 +164,34 @@ class Keyring(object):
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
lambda _: self.get_server_verify_keys(verify_requests)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
server_to_request_ids = {}
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
def remove_deferreds(res, server_name, verify_request):
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
preserve_context_over_fn(handle_key_deferred, verify_request)
for verify_request in verify_requests
]
@defer.inlineCallbacks
@ -220,7 +225,7 @@ class Keyring(object):
d.addBoth(rm, server_name)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
def get_server_verify_keys(self, verify_requests):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@ -237,62 +242,64 @@ class Keyring(object):
merged_results = {}
missing_keys = {}
for group in group_id_to_group.values():
missing_keys.setdefault(group.server_name, set()).update(
group.key_ids
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
# We now need to figure out which verify requests we have keys
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
server_name = verify_request.server_name
result_keys = merged_results[server_name]
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for key_id in verify_request.key_ids:
if key_id in result_keys:
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
verify_request.deferred.callback((
server_name,
key_id,
merged_results[group.server_name][key_id],
result_keys[key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_groups:
if not missing_keys:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
for verify_request in requests_missing_keys.values():
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
for verify_request in verify_requests:
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
@ -447,7 +454,7 @@ class Keyring(object):
)
processed_response = yield self.process_v2_response(
perspective_name, response
perspective_name, response, only_from_server=False
)
for server_name, response_keys in processed_response.items():
@ -527,7 +534,7 @@ class Keyring(object):
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json,
requested_ids=[]):
requested_ids=[], only_from_server=True):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@ -551,6 +558,13 @@ class Keyring(object):
results = {}
server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
raise ValueError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
import synapse.types
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, Requester
import logging
from synapse.api.errors import LimitExceededError
from synapse.types import UserID
logger = logging.getLogger(__name__)
@ -124,7 +124,8 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
requester = Requester(target_user, "", True)
requester = synapse.types.create_requester(
target_user, is_guest=True)
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,

View File

@ -77,6 +77,7 @@ class AuthHandler(BaseHandler):
self.ldap_bind_password = hs.config.ldap_bind_password
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@ -279,7 +280,16 @@ class AuthHandler(BaseHandler):
data = pde.response
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
if 'success' in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body['success'] else "Failed",
resp_body.get('hostname')
)
if resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@ -365,7 +375,8 @@ class AuthHandler(BaseHandler):
return self._check_password(user_id, password)
@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None):
def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
"""
Gets login tuple for the user with the given user ID.
@ -374,9 +385,15 @@ class AuthHandler(BaseHandler):
The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case.
The device will be recorded in the table if it is not there already.
Args:
user_id (str): canonical User ID
device_id (str): the device ID to associate with the access token
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns:
A tuple of:
The access token for the user's session.
@ -388,6 +405,16 @@ class AuthHandler(BaseHandler):
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks

View File

@ -79,17 +79,17 @@ class DeviceHandler(BaseHandler):
Args:
user_id (str):
Returns:
defer.Deferred: dict[str, dict[str, X]]: map from device_id to
info on the device
defer.Deferred: list[dict[str, X]]: info on each device
"""
devices = yield self.store.get_devices_by_user(user_id)
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in devices.keys())
devices=((user_id, device_id) for device_id in device_map.keys())
)
for device in devices.values():
devices = device_map.values()
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler):
Args:
user_id (str):
device_id (str)
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
@ -117,6 +117,61 @@ class DeviceHandler(BaseHandler):
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})

View File

@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, Requester
from synapse.types import UserID
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
try:
# Assume the user isn't a guest because we don't let guests set
# profile or avatar data.
requester = Requester(user, "", False)
# XXX why are we recreating `requester` here for each room?
# what was wrong with the `requester` we were passed?
requester = synapse.types.create_requester(user)
yield handler.update_membership(
requester,
user,

View File

@ -14,18 +14,19 @@
# limitations under the License.
"""Contains functions for registering clients."""
import logging
import urllib
from twisted.internet import defer
from synapse.types import UserID, Requester
import synapse.types
from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
import logging
import urllib
from synapse.types import UserID
from synapse.util.async import run_on_reactor
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@ -410,8 +411,9 @@ class RegistrationHandler(BaseHandler):
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname(
user, Requester(user, token, False), displayname
user, requester, displayname
)
defer.returnValue((user_id, token))

View File

@ -14,24 +14,22 @@
# limitations under the License.
import logging
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from twisted.internet import defer
from unpaddedbase64 import decode_base64
from ._base import BaseHandler
from synapse.types import UserID, RoomID, Requester
import synapse.types
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)

View File

@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
def register_paths(self, method, path_patterns, callback):
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback)
)

View File

@ -140,9 +140,8 @@ class EmailPusher(object):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
self.user_id, start, self.max_stream_ordering
)
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
soonest_due_at = None

View File

@ -141,7 +141,8 @@ class HttpPusher(object):
run once per pusher.
"""
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
unprocessed = yield fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)

View File

@ -54,7 +54,7 @@ def get_context_for_event(state_handler, ev, user_id):
room_state = yield state_handler.get_current_state(ev.room_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.namer, an alias or
# human-readable name instead, be that m.room.name, an alias or
# a list of people in the room
name = calculate_room_name(
room_state, user_id, fallback_to_single_member=False

View File

@ -93,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore):
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
get_unread_push_actions_for_user_in_range = (
DataStore.get_unread_push_actions_for_user_in_range.__func__
get_unread_push_actions_for_user_in_range_for_http = (
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
)
get_unread_push_actions_for_user_in_range_for_email = (
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
)
get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__

View File

@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet):
)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {
"user_id": user_id, # may have changed
@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet):
)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {
"user_id": user_id, # may have changed
@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet):
)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id
registered_user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {

View File

@ -13,19 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
import logging
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs):
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(RestServlet):
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
@ -70,6 +68,32 @@ class DeviceRestServlet(RestServlet):
)
defer.returnValue((200, device))
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint.
# It allows the client to delete access tokens, which feels like a
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one
# device at a time.
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server)

View File

@ -13,24 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
import synapse.api.errors
import synapse.server
import synapse.types
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from canonicaljson import encode_canonical_json
from ._base import client_v2_patterns
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class KeyUploadServlet(RestServlet):
"""
POST /keys/upload/<device_id> HTTP/1.1
POST /keys/upload HTTP/1.1
Content-Type: application/json
{
@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
},
}
"""
PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
releases=())
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead.
body = parse_json_object_from_request(request)
if device_id is not None:
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
if (requester.device_id is not None and
device_id != requester.device_id):
logger.warning("Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id, device_id)
else:
device_id = requester.device_id
if device_id is None:
raise synapse.api.errors.SynapseError(
400,
"To upload keys, you must pass device_id when authenticating"
)
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
@ -102,13 +125,14 @@ class KeyUploadServlet(RestServlet):
user_id, device_id, time_now, key_list
)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(
user_id, device_id, "unknown device"
)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))

View File

@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet):
"""
device_id = yield self._register_device(user_id, params)
access_token = yield self.auth_handler.issue_access_token(
user_id, device_id=device_id
access_token, refresh_token = (
yield self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
)
)
refresh_token = yield self.auth_handler.issue_refresh_token(
user_id, device_id=device_id
)
defer.returnValue({
"user_id": user_id,
"access_token": access_token,

View File

@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
"versions": ["r0.0.1"]
"versions": [
"r0.0.1",
"r0.1.0",
"r0.2.0",
]
})

View File

@ -14,6 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
from . import engines
from twisted.internet import defer
@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def start_doing_background_updates(self):
while True:
if self._background_update_timer is not None:
return
assert self._background_update_timer is None, \
"background updates already running"
logger.info("Starting background schema updates")
while True:
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_timer = None
try:
result = yield self.do_background_update(
result = yield self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
else:
if result is None:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
)
return
defer.returnValue(None)
@defer.inlineCallbacks
def do_background_update(self, desired_duration_ms):
"""Does some amount of work on a background update
def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on the next queued background update
Args:
desired_duration_ms(float): How long we want to spend
updating.
@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
# no work left to do
defer.returnValue(None)
# pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
res = yield self._do_background_update(update_name, desired_duration_ms)
defer.returnValue(res)
@defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms):
logger.info("Starting update batch on background update '%s'",
update_name)
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
def register_background_index_update(self, update_name, index_name,
table, columns):
"""Helper for store classes to do a background index addition
To use:
1. use a schema delta file to add a background update. Example:
INSERT INTO background_updates (update_name, progress_json) VALUES
('my_new_index', '{}');
2. In the Store constructor, call this method
Args:
update_name (str): update_name to register for
index_name (str): name of index to add
table (str): table to add index to
columns (list[str]): columns/expressions to include in index
"""
# if this is postgres, we add the indexes concurrently. Otherwise
# we fall back to doing it inline
if isinstance(self.database_engine, engines.PostgresEngine):
conc = True
else:
conc = False
sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
% {
"conc": "CONCURRENTLY" if conc else "",
"name": index_name,
"table": table,
"columns": ", ".join(columns),
}
def create_index_concurrently(conn):
conn.rollback()
# postgres insists on autocommit for the index
conn.set_session(autocommit=True)
c = conn.cursor()
c.execute(sql)
conn.set_session(autocommit=False)
def create_index(conn):
c = conn.cursor()
c.execute(sql)
@defer.inlineCallbacks
def updater(progress, batch_size):
logger.info("Adding index %s to %s", index_name, table)
if conc:
yield self.runWithConnection(create_index_concurrently)
else:
yield self.runWithConnection(create_index)
yield self._end_background_update(update_name)
defer.returnValue(1)
self.register_background_update_handler(update_name, updater)
def start_background_update(self, update_name, progress):
"""Starts a background update running.

View File

@ -15,10 +15,11 @@
import logging
from ._base import SQLBaseStore, Cache
from twisted.internet import defer
from ._base import Cache
from . import background_updates
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
@ -27,8 +28,7 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(SQLBaseStore):
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
@ -37,6 +37,13 @@ class ClientIpStore(SQLBaseStore):
super(ClientIpStore, self).__init__(hs)
self.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())

View File

@ -76,6 +76,46 @@ class DeviceStore(SQLBaseStore):
desc="get_device",
)
def delete_device(self, user_id, device_id):
"""Delete a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to delete
Returns:
defer.Deferred
"""
return self._simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_device",
)
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to update
new_display_name (str|None): new displayname for device; None
to leave unchanged
Raises:
StoreError: if the device is not found
Returns:
defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
return self._simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues=updates,
desc="update_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import twisted.internet.defer
from ._base import SQLBaseStore
@ -123,3 +125,16 @@ class EndToEndKeyStore(SQLBaseStore):
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
@twisted.internet.defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete(
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_device_keys_by_device"
)
yield self._simple_delete(
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_one_time_keys_by_device"
)

View File

@ -117,21 +117,149 @@ class EventPushActionsStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range(self, user_id,
min_stream_ordering,
max_stream_ordering=None,
limit=20):
def get_unread_push_actions_for_user_in_range_for_http(
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
):
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
Args:
user_id (str): The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the
stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the
stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return.
Returns:
A promise which resolves to a list of dicts with the keys "event_id",
"room_id", "stream_ordering", "actions".
The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries.
"""
# find rooms that have a read receipt in them and return the next
# push actions
def get_after_receipt(txn):
# find rooms that have a read receipt in them and return the next
# push actions
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
" FROM ("
" SELECT room_id,"
" MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
") AS rl,"
" event_push_actions AS ep"
" WHERE"
" ep.room_id = rl.room_id"
" AND ("
" ep.topological_ordering > rl.topological_ordering"
" OR ("
" ep.topological_ordering = rl.topological_ordering"
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id NOT IN ("
" SELECT room_id FROM receipts_linearized"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
" )"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
notifs = [
{
"event_id": row[0],
"room_id": row[1],
"stream_ordering": row[2],
"actions": json.loads(row[3]),
} for row in after_read_receipt + no_read_receipt
]
# Now sort it so it's ordered correctly, since currently it will
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by stream_ordering, oldest first.
notifs.sort(key=lambda r: r['stream_ordering'])
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
defer.returnValue(notifs[:limit])
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_email(
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
):
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
Args:
user_id (str): The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the
stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the
stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return.
Returns:
A promise which resolves to a list of dicts with the keys "event_id",
"room_id", "stream_ordering", "actions", "received_ts".
The list will be ordered by descending received_ts.
The list will have between 0~limit entries.
"""
# find rooms that have a read receipt in them and return the most recent
# push actions
def get_after_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
"e.received_ts "
"FROM ("
" SELECT room_id, user_id, "
" max(topological_ordering) as topological_ordering, "
" max(stream_ordering) as stream_ordering "
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM ("
" SELECT room_id,"
" MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
" GROUP BY room_id, user_id"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
") AS rl,"
" event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
@ -144,44 +272,49 @@ class EventPushActionsStore(SQLBaseStore):
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
" AND ep.stream_ordering > ?"
" AND ep.user_id = ?"
" AND ep.user_id = rl.user_id"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [min_stream_ordering, user_id]
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
args.append(limit)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_after_receipt
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
# There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too.
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.room_id not in ("
" SELECT room_id FROM events NATURAL JOIN receipts_linearized"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id NOT IN ("
" SELECT room_id FROM receipts_linearized"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" GROUP BY room_id"
") AND ep.user_id = ? AND ep.stream_ordering > ?"
" )"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering]
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
args.append(limit)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_no_receipt
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
# Make a list of dicts from the two sets of results.
@ -198,7 +331,7 @@ class EventPushActionsStore(SQLBaseStore):
# Now sort it so it's ordered correctly, since currently it will
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by received_ts
# by received_ts (most recent first)
notifs.sort(key=lambda r: -(r['received_ts'] or 0))
# Now return the first `limit`

View File

@ -397,6 +397,12 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
"""Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table,
and the rejections table. Things reading from those table will need to check
whether the event was rejected.
"""
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
@ -407,21 +413,11 @@ class EventsStore(SQLBaseStore):
event.room_id, event.internal_metadata.stream_ordering,
)
if not event.internal_metadata.is_outlier():
if not event.internal_metadata.is_outlier() and not context.rejected:
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
if context.push_actions:
self._set_push_actions_for_event_and_users_txn(
txn, event, context.push_actions
)
if event.type == EventTypes.Redaction and event.redacts is not None:
self._remove_push_actions_for_event_id_txn(
txn, event.room_id, event.redacts
)
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
@ -431,14 +427,24 @@ class EventsStore(SQLBaseStore):
),
[event.event_id for event, _ in events_and_contexts]
)
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
# Remove the events that we've seen before.
event_map = {}
to_remove = set()
for event, context in events_and_contexts:
if context.rejected:
# If the event is rejected then we don't care if the event
# was an outlier or not.
if event.event_id in have_persisted:
# If we have already seen the event then ignore it.
to_remove.add(event)
continue
# Handle the case of the list including the same event multiple
# times. The tricky thing here is when they differ by whether
# they are an outlier.
@ -463,6 +469,12 @@ class EventsStore(SQLBaseStore):
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
# an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state.
# insert into the state_group, state_groups_state and
# event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, ((event, context),))
metadata_json = encode_json(
@ -478,6 +490,8 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,)
)
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
@ -499,6 +513,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,)
)
# Update the event_backward_extremities table now that this
# event isn't an outlier any more.
self._update_extremeties(txn, [event])
events_and_contexts = [
@ -506,38 +522,12 @@ class EventsStore(SQLBaseStore):
]
if not events_and_contexts:
# Make sure we don't pass an empty list to functions that expect to
# be storing at least one element.
return
self._store_mult_state_groups_txn(txn, events_and_contexts)
self._handle_mult_prev_events(
txn,
events=[event for event, _ in events_and_contexts],
)
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
self._store_guest_access_txn(txn, event)
self._store_room_members_txn(
txn,
[
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
backfilled=backfilled,
)
# From this point onwards the events are only events that we haven't
# seen before.
def event_dict(event):
return {
@ -591,10 +581,41 @@ class EventsStore(SQLBaseStore):
],
)
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
to_remove = set()
for event, context in events_and_contexts:
if context.rejected:
# Insert the event_id into the rejections table
self._store_rejections_txn(
txn, event.event_id, context.rejected
)
to_remove.add(event)
events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
if not events_and_contexts:
# Make sure we don't pass an empty list to functions that expect to
# be storing at least one element.
return
# From this point onwards the events are only ones that weren't rejected.
for event, context in events_and_contexts:
# Insert all the push actions into the event_push_actions table.
if context.push_actions:
self._set_push_actions_for_event_and_users_txn(
txn, event, context.push_actions
)
if event.type == EventTypes.Redaction and event.redacts is not None:
# Remove the entries in the event_push_actions table for the
# redacted event.
self._remove_push_actions_for_event_id_txn(
txn, event.room_id, event.redacts
)
self._simple_insert_many_txn(
txn,
@ -610,6 +631,49 @@ class EventsStore(SQLBaseStore):
],
)
# Insert into the state_groups, state_groups_state, and
# event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, events_and_contexts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
txn,
events=[event for event, _ in events_and_contexts],
)
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
# Insert into the room_names and event_search tables.
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
# Insert into the topics table and event_search table.
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
# Insert into the event_search table.
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction:
# Insert into the redactions table.
self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
# Insert into the event_search table.
self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
[
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
backfilled=backfilled,
)
# Insert event_reference_hashes table.
self._store_event_reference_hashes_txn(
txn, [event for event, _ in events_and_contexts]
)
@ -654,6 +718,7 @@ class EventsStore(SQLBaseStore):
],
)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
if backfilled:
@ -666,11 +731,6 @@ class EventsStore(SQLBaseStore):
# Outlier events shouldn't clobber the current state.
continue
if context.rejected:
# If the event failed it's auth checks then it shouldn't
# clobbler the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)

View File

@ -18,18 +18,31 @@ import re
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
from synapse.storage import background_updates
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(SQLBaseStore):
class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock()
self.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
self.register_background_index_update(
"refresh_tokens_device_index",
index_name="refresh_tokens_device_id",
table="refresh_tokens",
columns=["user_id", "device_id"],
)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@ -238,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[]):
def f(txn):
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
def user_delete_access_tokens(self, user_id, except_token_ids=[],
device_id=None,
delete_refresh_tokens=False):
"""
Invalidate access/refresh tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
delete_refresh_tokens (bool): True to delete refresh tokens as
well as access tokens.
Returns:
defer.Deferred:
"""
def f(txn, table, except_tokens, call_after_delete):
sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
if except_token_ids:
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)
if except_tokens:
sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_token_ids]),
",".join(["?" for _ in except_tokens]),
)
clauses += except_token_ids
clauses += except_tokens
txn.execute(sql, clauses)
@ -256,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
txn.call_after(call_after_delete, (row[0],))
txn.execute(
"DELETE FROM access_tokens WHERE token in (%s)" % (
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
)
yield self.runInteraction("user_delete_access_tokens", f)
# delete refresh tokens first, to stop new access tokens being
# allocated while our backs are turned
if delete_refresh_tokens:
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
)
yield self.runInteraction(
"user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
)
def delete_access_token(self, access_token):
def f(txn):
@ -288,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
Args:
token (str): The access token of a user.
Returns:
dict: Including the name (user_id) and the ID of their access token.
Raises:
StoreError if no user was found.
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",

View File

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('access_tokens_device_index', '{}');

View File

@ -0,0 +1,19 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- make sure that we have a device record for each set of E2E keys, so that the
-- user can delete them if they like.
INSERT INTO devices
SELECT user_id, device_id, 'unknown device' FROM e2e_device_keys_json;

View File

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('refresh_tokens_device_index', '{}');

View File

@ -13,4 +13,5 @@
* limitations under the License.
*/
CREATE INDEX user_ips_device_id ON user_ips(user_id, device_id, last_seen);
INSERT INTO background_updates (update_name, progress_json) VALUES
('user_ips_device_index', '{}');

View File

@ -24,6 +24,7 @@ from collections import namedtuple
import itertools
import logging
import ujson as json
logger = logging.getLogger(__name__)
@ -101,7 +102,7 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
return result["response_code"], result["response_json"]
return result["response_code"], json.loads(str(result["response_json"]))
else:
return None

View File

@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError
from collections import namedtuple
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
Requester = namedtuple("Requester",
["user", "access_token_id", "is_guest", "device_id"])
"""
Represents the user making a request
Attributes:
user (UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
device_id (str|None): device_id which was set at authentication time
"""
def create_requester(user_id, access_token_id=None, is_guest=False,
device_id=None):
"""
Create a new ``Requester`` object
Args:
user_id (str|UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
device_id (str|None): device_id which was set at authentication time
Returns:
Requester
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
return Requester(user_id, access_token_id, is_guest, device_id)
def get_domain_from_id(string):

View File

@ -84,7 +84,7 @@ class Measure(object):
if context != self.start_context:
logger.warn(
"Context have unexpectedly changed from '%s' to '%s'. (%r)",
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
context, self.start_context, self.name
)
return

View File

@ -83,7 +83,10 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
):
if ("m.room.member", my_member_event.sender) in room_state:
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
if fallback_to_single_member:
return "Invite from %s" % (name_from_member_event(inviter_member_event),)
else:
return None
else:
return "Room Invite"

View File

@ -128,7 +128,7 @@ class RetryDestinationLimiter(object):
)
valid_err_code = False
if exc_type is CodeMessageException:
if exc_type is not None and issubclass(exc_type, CodeMessageException):
valid_err_code = 0 <= exc_val.code < 500
if exc_type is None or valid_err_code:

View File

@ -12,11 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse import types
from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.device
import synapse.storage
from synapse import types
from tests import unittest, utils
user1 = "@boris:aaa"
@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
self.handler = None # type: device.DeviceHandler
self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock
@defer.inlineCallbacks
@ -84,28 +87,31 @@ class DeviceTestCase(unittest.TestCase):
yield self._record_users()
res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res.keys()))
self.assertEqual(3, len(res))
device_map = {
d["device_id"]: d for d in res
}
self.assertDictContainsSubset({
"user_id": user1,
"device_id": "xyz",
"display_name": "display 0",
"last_seen_ip": None,
"last_seen_ts": None,
}, res["xyz"])
}, device_map["xyz"])
self.assertDictContainsSubset({
"user_id": user1,
"device_id": "fco",
"display_name": "display 1",
"last_seen_ip": "ip1",
"last_seen_ts": 1000000,
}, res["fco"])
}, device_map["fco"])
self.assertDictContainsSubset({
"user_id": user1,
"device_id": "abc",
"display_name": "display 2",
"last_seen_ip": "ip3",
"last_seen_ts": 3000000,
}, res["abc"])
}, device_map["abc"])
@defer.inlineCallbacks
def test_get_device(self):
@ -120,6 +126,37 @@ class DeviceTestCase(unittest.TestCase):
"last_seen_ts": 3000000,
}, res)
@defer.inlineCallbacks
def test_delete_device(self):
yield self._record_users()
# delete the device
yield self.handler.delete_device(user1, "abc")
# check the device was deleted
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.get_device(user1, "abc")
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
@defer.inlineCallbacks
def test_update_device(self):
yield self._record_users()
update = {"display_name": "new display"}
yield self.handler.update_device(user1, "abc", update)
res = yield self.handler.get_device(user1, "abc")
self.assertEqual(res["display_name"], "new display")
@defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id",
update)
@defer.inlineCallbacks
def _record_users(self):
# check this works for both devices which have a recorded client_ip,

View File

@ -19,11 +19,12 @@ from twisted.internet import defer
from mock import Mock, NonCallableMock
import synapse.types
from synapse.api.errors import AuthError
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
from tests.utils import setup_test_homeserver, requester_for_user
from tests.utils import setup_test_homeserver
class ProfileHandlers(object):
@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name(self):
yield self.handler.set_displayname(
self.frank,
requester_for_user(self.frank),
synapse.types.create_requester(self.frank),
"Frank Jr."
)
@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(
self.frank,
requester_for_user(self.bob),
synapse.types.create_requester(self.bob),
"Frank Jr."
)
@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
yield self.handler.set_avatar_url(
self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
self.frank, synapse.types.create_requester(self.frank),
"http://my.server/pic.gif"
)
self.assertEquals(

View File

@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.resource import ReplicationResource
from synapse.types import Requester, UserID
from twisted.internet import defer
from tests import unittest
from tests.utils import setup_test_homeserver, requester_for_user
from mock import Mock, NonCallableMock
import json
import contextlib
import json
from mock import Mock, NonCallableMock
from twisted.internet import defer
import synapse.types
from synapse.replication.resource import ReplicationResource
from synapse.types import UserID
from tests import unittest
from tests.utils import setup_test_homeserver
class ReplicationResourceCase(unittest.TestCase):
@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase):
def test_events_and_state(self):
get = self.get(events="-1", state="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room(
Requester(self.user, "", False), {}
synapse.types.create_requester(self.user), {}
)
code, body = yield get
self.assertEquals(code, 200)
@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase):
def send_text_message(self, room_id, message):
handler = self.hs.get_handlers().message_handler
event = yield handler.create_and_send_nonmember_event(
requester_for_user(self.user),
synapse.types.create_requester(self.user),
{
"type": "m.room.message",
"content": {"body": "message", "msgtype": "m.text"},
@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase):
@defer.inlineCallbacks
def create_room(self):
result = yield self.hs.get_handlers().room_creation_handler.create_room(
Requester(self.user, "", False), {}
synapse.types.create_requester(self.user), {}
)
defer.returnValue(result["room_id"])

View File

@ -14,17 +14,14 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
from tests import unittest
from mock import Mock
from twisted.internet import defer
from mock import Mock
from ....utils import MockHttpResource, setup_test_homeserver
import synapse.types
from synapse.api.errors import SynapseError, AuthError
from synapse.types import Requester, UserID
from synapse.rest.client.v1 import profile
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1"
@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False)
return synapse.types.create_requester(myid)
hs.get_v1auth().get_user_by_req = _get_user_by_req

View File

@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.appservice_register = Mock(
return_value=user_id
)
self.auth_handler.issue_access_token = Mock(return_value=token)
self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
)
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname
}
self.assertDictContainsSubset(det_data, result)
@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.issue_access_token = Mock(return_value=token)
self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
)
self.device_handler.check_device_registered = \
Mock(return_value=device_id)
@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = {
"user_id": user_id,
"access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
self.auth_handler.issue_access_token.assert_called_once_with(
user_id, device_id=device_id)
self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None)
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False

View File

@ -10,7 +10,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver()
hs = yield setup_test_homeserver() # type: synapse.server.HomeServer
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@ -20,11 +20,20 @@ class BackgroundUpdateTestCase(unittest.TestCase):
"test_update", self.update_handler
)
# run the real background updates, to get them out the way
# (perhaps we should run them as part of the test HS setup, since we
# run all of the other schema setup stuff there?)
while True:
res = yield self.store.do_next_background_update(1000)
if res is None:
break
@defer.inlineCallbacks
def test_do_background_update(self):
desired_count = 1000
duration_ms = 42
# first step: make a bit of progress
@defer.inlineCallbacks
def update(progress, count):
self.clock.advance_time_msec(count * duration_ms)
@ -42,7 +51,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
yield self.store.start_background_update("test_update", {"my_key": 1})
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
@ -50,15 +59,15 @@ class BackgroundUpdateTestCase(unittest.TestCase):
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
defer.returnValue(count)
self.update_handler.side_effect = update
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
@ -66,8 +75,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
{"my_key": 2}, desired_count
)
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNone(result)

View File

@ -15,6 +15,7 @@
from twisted.internet import defer
import synapse.api.errors
import tests.unittest
import tests.utils
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"device_id": "device2",
"display_name": "display_name 2",
}, res["device2"])
@defer.inlineCallbacks
def test_update_device(self):
yield self.store.store_device(
"user_id", "device_id", "display_name 1"
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield self.store.update_device(
"user_id", "device_id",
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do the update
yield self.store.update_device(
"user_id", "device_id",
new_display_name="display_name 2",
)
# check it worked
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield self.store.update_device(
"user_id", "unknown_device_id",
new_display_name="display_name 2",
)
self.assertEqual(404, cm.exception.code)

View File

@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.unittest
import tests.utils
USER_ID = "@user:example.com"
class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
yield self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
@defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
yield self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)

View File

@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
generator = TokenGenerator()
refresh_token = generator.generate(self.user_id)
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
self.device_id)
yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
self.device_id)
# now delete some
yield self.store.user_delete_access_tokens(
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
# check they were deleted
user = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertIsNone(user, "access token was not deleted by device_id")
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(refresh_token,
generator.generate)
# check the one not associated with the device was not deleted
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertEqual(self.user_id, user["name"])
# now delete the rest
yield self.store.user_delete_access_tokens(
self.user_id, delete_refresh_tokens=True)
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertIsNone(user,
"access token was not deleted without device_id")
class TokenGenerator:
def __init__(self):

View File

@ -17,13 +17,18 @@ from twisted.trial import unittest
import logging
# logging doesn't have a "don't log anything at all EVARRRR setting,
# but since the highest value is 50, 1000000 should do ;)
NEVER = 1000000
logging.getLogger().addHandler(logging.StreamHandler())
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(
"%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]"
))
logging.getLogger().addHandler(handler)
logging.getLogger().setLevel(NEVER)
logging.getLogger("synapse.storage.SQL").setLevel(NEVER)
logging.getLogger("synapse.storage.txn").setLevel(NEVER)
def around(target):
@ -70,8 +75,6 @@ class TestCase(unittest.TestCase):
return ret
logging.getLogger().setLevel(level)
# Don't set SQL logging
logging.getLogger("synapse.storage").setLevel(old_level)
return orig()
def assertObjectHasAttributes(self, attrs, obj):

View File

@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
from synapse.federation.transport import server
from synapse.types import Requester
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@ -512,7 +511,3 @@ class DeferredMockCallable(object):
"call(%s)" % _format_call(c[0], c[1]) for c in calls
])
)
def requester_for_user(user):
return Requester(user, None, False)