Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2018-08-16 15:51:16 +01:00
commit bb795b56da
24 changed files with 502 additions and 55 deletions

View File

@ -56,17 +56,17 @@ entry. These are managed by Towncrier
(https://github.com/hawkowl/towncrier).
To create a changelog entry, make a new file in the ``changelog.d``
file named in the format of ``issuenumberOrPR.type``. The type can be
file named in the format of ``PRnumber.type``. The type can be
one of ``feature``, ``bugfix``, ``removal`` (also used for
deprecations), or ``misc`` (for internal-only changes). The content of
the file is your changelog entry, which can contain RestructuredText
formatting. A note of contributors is welcomed in changelogs for
non-misc changes (the content of misc changes is not displayed).
For example, a fix for a bug reported in #1234 would have its
changelog entry in ``changelog.d/1234.bugfix``, and contain content
like "The security levels of Florbs are now validated when
recieved over federation. Contributed by Jane Matrix".
For example, a fix in PR #1234 would have its changelog entry in
``changelog.d/1234.bugfix``, and contain content like "The security levels of
Florbs are now validated when recieved over federation. Contributed by Jane
Matrix".
Attribution
~~~~~~~~~~~

1
changelog.d/3184.feature Normal file
View File

@ -0,0 +1 @@
Add /_media/r0/config

1
changelog.d/3568.feature Normal file
View File

@ -0,0 +1 @@
speed up /members API and add `at` and `membership` params as per MSC1227

1
changelog.d/3574.feature Normal file
View File

@ -0,0 +1 @@
implement `summary` block in /sync response as per MSC688

1
changelog.d/3589.feature Normal file
View File

@ -0,0 +1 @@
Add lazy-loading support to /messages as per MSC1227

1
changelog.d/3689.bugfix Normal file
View File

@ -0,0 +1 @@
Fix mau blocking calulation bug on login

1
changelog.d/3705.bugfix Normal file
View File

@ -0,0 +1 @@
Support more federation endpoints on workers

View File

@ -32,6 +32,7 @@ from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
@ -54,6 +55,7 @@ logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore(
SlavedAccountDataStore,
SlavedProfileStore,
SlavedApplicationServiceStore,
SlavedPusherStore,

View File

@ -520,7 +520,7 @@ class AuthHandler(BaseHandler):
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
yield self.auth.check_auth_blocking()
yield self.auth.check_auth_blocking(user_id)
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
@ -734,7 +734,6 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token):
yield self.auth.check_auth_blocking()
auth_api = self.hs.get_auth()
user_id = None
try:
@ -743,6 +742,7 @@ class AuthHandler(BaseHandler):
auth_api.validate_macaroon(macaroon, "login", True, user_id)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
yield self.auth.check_auth_blocking(user_id)
defer.returnValue(user_id)
@defer.inlineCallbacks

View File

@ -25,7 +25,13 @@ from twisted.internet import defer
from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
from synapse.api.errors import (
AuthError,
Codes,
ConsentNotGivenError,
NotFoundError,
SynapseError,
)
from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
@ -36,6 +42,7 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@ -82,26 +89,83 @@ class MessageHandler(object):
defer.returnValue(data)
@defer.inlineCallbacks
def get_state_events(self, user_id, room_id, is_guest=False):
def get_state_events(
self, user_id, room_id, types=None, filtered_types=None,
at_token=None, is_guest=False,
):
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left.
left the room return the state events from when they left. If an explicit
'at' parameter is passed, return the state events as of that event, if
visible.
Args:
user_id(str): The user requesting state events.
room_id(str): The room ID to get all state events from.
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
at_token(StreamToken|None): the stream token of the at which we are requesting
the stats. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current
state based on the current_state_events table.
is_guest(bool): whether this user is a guest
Returns:
A list of dicts representing state events. [{}, {}, {}]
Raises:
NotFoundError (404) if the at token does not yield an event
AuthError (403) if the user doesn't have permission to view
members of this room.
"""
membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
room_id, user_id
if at_token:
# FIXME this claims to get the state at a stream position, but
# get_recent_events_for_room operates by topo ordering. This therefore
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
room_id, end_token=at_token.room_key, limit=1,
)
if not last_events:
raise NotFoundError("Can't find event for token %s" % (at_token, ))
visible_events = yield filter_events_for_client(
self.store, user_id, last_events,
)
event = last_events[0]
if visible_events:
room_state = yield self.store.get_state_for_events(
[event.event_id], types, filtered_types=filtered_types,
)
room_state = room_state[event.event_id]
else:
raise AuthError(
403,
"User %s not allowed to view events in room %s at token %s" % (
user_id, room_id, at_token,
)
)
else:
membership, membership_event_id = (
yield self.auth.check_in_room_or_world_readable(
room_id, user_id,
)
)
if membership == Membership.JOIN:
room_state = yield self.state.get_current_state(room_id)
state_ids = yield self.store.get_filtered_current_state_ids(
room_id, types, filtered_types=filtered_types,
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
[membership_event_id], None
[membership_event_id], types, filtered_types=filtered_types,
)
room_state = room_state[membership_event_id]

View File

@ -18,7 +18,7 @@ import logging
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.api.constants import Membership
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
from synapse.types import RoomStreamToken
@ -251,6 +251,33 @@ class PaginationHandler(object):
is_peeking=(member_event_id is None),
)
state = None
if event_filter and event_filter.lazy_load_members():
# TODO: remove redundant members
types = [
(EventTypes.Member, state_key)
for state_key in set(
event.sender # FIXME: we also care about invite targets etc.
for event in events
)
]
state_ids = yield self.store.get_state_ids_for_event(
events[0].event_id, types=types,
)
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
if state:
state = yield filter_events_for_client(
self.store,
user_id,
state.values(),
is_peeking=(member_event_id is None),
)
time_now = self.clock.time_msec()
chunk = {
@ -262,4 +289,10 @@ class PaginationHandler(object):
"end": next_token.to_string(),
}
if state:
chunk["state"] = [
serialize_event(e, time_now, as_client_event)
for e in state
]
defer.returnValue(chunk)

View File

@ -77,6 +77,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"ephemeral",
"account_data",
"unread_notifications",
"summary",
])):
__slots__ = []
@ -507,10 +508,142 @@ class SyncHandler(object):
state = {}
defer.returnValue(state)
@defer.inlineCallbacks
def compute_summary(self, room_id, sync_config, batch, state, now_token):
""" Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds
state events to 'state' if needed to describe the heroes.
Args:
room_id(str):
sync_config(synapse.handlers.sync.SyncConfig):
batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
the room that will be sent to the user.
state(dict): dict of (type, state_key) -> Event as returned by
compute_state_delta
now_token(str): Token of the end of the current batch.
Returns:
A deferred dict describing the room summary
"""
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
last_events, _ = yield self.store.get_recent_event_ids_for_room(
room_id, end_token=now_token.room_key, limit=1,
)
if not last_events:
defer.returnValue(None)
return
last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event(
last_event.event_id, [
(EventTypes.Member, None),
(EventTypes.Name, ''),
(EventTypes.CanonicalAlias, ''),
]
)
member_ids = {
state_key: event_id
for (t, state_key), event_id in state_ids.iteritems()
if t == EventTypes.Member
}
name_id = state_ids.get((EventTypes.Name, ''))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ''))
summary = {}
# FIXME: it feels very heavy to load up every single membership event
# just to calculate the counts.
member_events = yield self.store.get_events(member_ids.values())
joined_user_ids = []
invited_user_ids = []
for ev in member_events.values():
if ev.content.get("membership") == Membership.JOIN:
joined_user_ids.append(ev.state_key)
elif ev.content.get("membership") == Membership.INVITE:
invited_user_ids.append(ev.state_key)
# TODO: only send these when they change.
summary["m.joined_member_count"] = len(joined_user_ids)
summary["m.invited_member_count"] = len(invited_user_ids)
if name_id or canonical_alias_id:
defer.returnValue(summary)
# FIXME: order by stream ordering, not alphabetic
me = sync_config.user.to_string()
if (joined_user_ids or invited_user_ids):
summary['m.heroes'] = sorted(
[
user_id
for user_id in (joined_user_ids + invited_user_ids)
if user_id != me
]
)[0:5]
else:
summary['m.heroes'] = sorted(
[user_id for user_id in member_ids.keys() if user_id != me]
)[0:5]
if not sync_config.filter_collection.lazy_load_members():
defer.returnValue(summary)
# ensure we send membership events for heroes if needed
cache_key = (sync_config.user.to_string(), sync_config.device_id)
cache = self.get_lazy_loaded_members_cache(cache_key)
# track which members the client should already know about via LL:
# Ones which are already in state...
existing_members = set(
user_id for (typ, user_id) in state.keys()
if typ == EventTypes.Member
)
# ...or ones which are in the timeline...
for ev in batch.events:
if ev.type == EventTypes.Member:
existing_members.add(ev.state_key)
# ...and then ensure any missing ones get included in state.
missing_hero_event_ids = [
member_ids[hero_id]
for hero_id in summary['m.heroes']
if (
cache.get(hero_id) != member_ids[hero_id] and
hero_id not in existing_members
)
]
missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values()
for s in missing_hero_state:
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
defer.returnValue(summary)
def get_lazy_loaded_members_cache(self, cache_key):
cache = self.lazy_loaded_members_cache.get(cache_key)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
self.lazy_loaded_members_cache[cache_key] = cache
else:
logger.debug("found LruCache for %r", cache_key)
return cache
@defer.inlineCallbacks
def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
full_state):
""" Works out the differnce in state between the start of the timeline
""" Works out the difference in state between the start of the timeline
and the previous sync.
Args:
@ -524,7 +657,7 @@ class SyncHandler(object):
full_state(bool): Whether to force returning the full state.
Returns:
A deferred new event dictionary
A deferred dict of (type, state_key) -> Event
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
@ -622,13 +755,7 @@ class SyncHandler(object):
if lazy_load_members and not include_redundant_members:
cache_key = (sync_config.user.to_string(), sync_config.device_id)
cache = self.lazy_loaded_members_cache.get(cache_key)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
self.lazy_loaded_members_cache[cache_key] = cache
else:
logger.debug("found LruCache for %r", cache_key)
cache = self.get_lazy_loaded_members_cache(cache_key)
# if it's a new sync sequence, then assume the client has had
# amnesia and doesn't want any recent lazy-loaded members
@ -1429,7 +1556,6 @@ class SyncHandler(object):
if events == [] and tags is None:
return
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
@ -1472,6 +1598,18 @@ class SyncHandler(object):
full_state=full_state
)
summary = {}
if (
sync_config.filter_collection.lazy_load_members() and
(
any(ev.type == EventTypes.Member for ev in batch.events) or
since_token is None
)
):
summary = yield self.compute_summary(
room_id, sync_config, batch, state, now_token
)
if room_builder.rtype == "joined":
unread_notifications = {}
room_sync = JoinedSyncResult(
@ -1481,6 +1619,7 @@ class SyncHandler(object):
ephemeral=ephemeral,
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
)
if room_sync or always_include:

View File

@ -34,7 +34,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from .base import ClientV1RestServlet, client_path_patterns
@ -384,15 +384,39 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
requester = yield self.auth.get_user_by_req(request)
events = yield self.message_handler.get_state_events(
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
# for consistency with /messages etc.
# useful for getting the membership in retrospect as of a given /sync
# response.
at_token_string = parse_string(request, "at")
if at_token_string is None:
at_token = None
else:
at_token = StreamToken.from_string(at_token_string)
# let you filter down on particular memberships.
# XXX: this may not be the best shape for this API - we could pass in a filter
# instead, except filters aren't currently aware of memberships.
# See https://github.com/matrix-org/matrix-doc/issues/1337 for more details.
membership = parse_string(request, "membership")
not_membership = parse_string(request, "not_membership")
events = yield handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
at_token=at_token,
types=[(EventTypes.Member, None)],
)
chunk = []
for event in events:
if event["type"] != EventTypes.Member:
if (
(membership and event['content'].get("membership") != membership) or
(not_membership and event['content'].get("membership") == not_membership)
):
continue
chunk.append(event)
@ -401,6 +425,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
}))
# deprecated in favour of /members?membership=join?
# except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")

View File

@ -370,6 +370,7 @@ class SyncRestServlet(RestServlet):
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
return result

View File

@ -27,11 +27,22 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
"versions": [
# XXX: at some point we need to decide whether we need to include
# the previous version numbers, given we've defined r0.3.0 to be
# backwards compatible with r0.2.0. But need to check how
# conscientious we've been in compatibility, and decide whether the
# middle number is the major revision when at 0.X.Y (as opposed to
# X.Y.Z). And we need to decide whether it's fair to make clients
# parse the version string to figure out what's going on.
"r0.0.1",
"r0.1.0",
"r0.2.0",
"r0.3.0",
]
],
# as per MSC1497:
"unstable_features": {
"m.lazy_load_members": True,
}
})

View File

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.http.server import respond_with_json, wrap_json_request_handler
class MediaConfigResource(Resource):
isLeaf = True
def __init__(self, hs):
Resource.__init__(self)
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {
"m.upload.size": config.max_upload_size,
}
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
yield self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict)
def render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET

View File

@ -42,6 +42,7 @@ from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import is_ascii, random_string
from ._base import FileInfo, respond_404, respond_with_responder
from .config_resource import MediaConfigResource
from .download_resource import DownloadResource
from .filepath import MediaFilePaths
from .identicon_resource import IdenticonResource
@ -754,7 +755,6 @@ class MediaRepositoryResource(Resource):
Resource.__init__(self)
media_repo = hs.get_media_repository()
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource(
@ -765,3 +765,4 @@ class MediaRepositoryResource(Resource):
self.putChild("preview_url", PreviewUrlResource(
hs, media_repo, media_repo.media_storage,
))
self.putChild("config", MediaConfigResource(hs))

View File

@ -1150,17 +1150,16 @@ class SQLBaseStore(object):
defer.returnValue(retval)
def get_user_count_txn(self, txn):
"""Get a total number of registerd users in the users list.
"""Get a total number of registered users in the users list.
Args:
txn : Transaction object
Returns:
defer.Deferred: resolves to int
int : number of users
"""
sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
txn.execute(sql_count)
count = txn.fetchone()[0]
defer.returnValue(count)
return txn.fetchone()[0]
def _simple_search_list(self, table, term, col, retcols,
desc="_simple_search_list"):

View File

@ -1911,7 +1911,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
max_depth = max(row[0] for row in rows)
if max_depth <= token.topological:
# We need to ensure we don't delete all the events from the datanase
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
raise SynapseError(

View File

@ -116,6 +116,69 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
_get_current_state_ids_txn,
)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Args:
room_id (str)
types (list[(Str, (Str|None))]): List of (type, state_key) tuples
which are used to filter the state fetched. `state_key` may be
None, which matches any `state_key`
filtered_types (list[Str]|None): List of types to apply the above filter to.
Returns:
deferred: dict of (type, state_key) -> event
"""
include_other_types = False if filtered_types is None else True
def _get_filtered_current_state_ids_txn(txn):
results = {}
sql = """SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ? %s"""
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type seperately.
if types:
clause_to_args = [
(
"AND type = ? AND state_key = ?",
(etype, state_key)
) if state_key is not None else (
"AND type = ?",
(etype,)
)
for etype, state_key in types
]
if include_other_types:
unique_types = set(filtered_types)
clause_to_args.append(
(
"AND type <> ? " * len(unique_types),
list(unique_types)
)
)
else:
# If types is None we fetch all the state, and so just use an
# empty where clause with no extra args.
clause_to_args = [("", [])]
for where_clause, where_args in clause_to_args:
args = [room_id]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (intern_string(typ), intern_string(state_key))
results[key] = event_id
return results
return self.runInteraction(
"get_filtered_current_state_ids",
_get_filtered_current_state_ids_txn,
)
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@ -389,8 +452,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
deferred: A list of dicts corresponding to the event_ids given.
The dicts are mappings from (type, state_key) -> state_events
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
@ -418,7 +480,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
"""
Get the state dicts corresponding to a list of events
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
@ -431,7 +494,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events.
Returns:
A deferred dict from event_id -> (type, state_key) -> state_event
A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,

View File

@ -124,7 +124,7 @@ class AuthTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
def test_mau_limits_exceeded(self):
def test_mau_limits_exceeded_large(self):
self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
@ -141,6 +141,42 @@ class AuthTestCase(unittest.TestCase):
self._get_macaroon().serialize()
)
@defer.inlineCallbacks
def test_mau_limits_parity(self):
self.hs.config.limit_usage_by_mau = True
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True

View File

@ -98,7 +98,7 @@ class RegistrationTestCase(unittest.TestCase):
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock(
return_value=defer.succeed(self.small_number_of_users)
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
yield self.handler.get_or_create_user("@user:server", 'c', "User")
@ -112,6 +112,12 @@ class RegistrationTestCase(unittest.TestCase):
with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name")
@defer.inlineCallbacks
def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
@ -121,6 +127,12 @@ class RegistrationTestCase(unittest.TestCase):
with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks
def test_register_saml2_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
@ -129,3 +141,9 @@ class RegistrationTestCase(unittest.TestCase):
)
with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part")

View File

@ -148,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
)
# check we can use filter_types to grab a specific room member
# check we can use filtered_types to grab a specific room member
# without filtering out the other event types
state = yield self.store.get_state_for_event(
e5.event_id,