Merge branch 'develop' of github.com:matrix-org/synapse into neilj/server_notices_on_blocking

This commit is contained in:
Neil Johnson 2018-08-15 17:20:38 +01:00
commit 8cfad2e686
10 changed files with 232 additions and 24 deletions

1
changelog.d/3568.feature Normal file
View File

@ -0,0 +1 @@
speed up /members API and add `at` and `membership` params as per MSC1227

1
changelog.d/3689.bugfix Normal file
View File

@ -0,0 +1 @@
Fix mau blocking calulation bug on login

View File

@ -520,7 +520,7 @@ class AuthHandler(BaseHandler):
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking(user_id)
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -734,7 +734,6 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
yield self.auth.check_auth_blocking()
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
user_id = None user_id = None
try: try:
@ -743,6 +742,7 @@ class AuthHandler(BaseHandler):
auth_api.validate_macaroon(macaroon, "login", True, user_id) auth_api.validate_macaroon(macaroon, "login", True, user_id)
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
yield self.auth.check_auth_blocking(user_id)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -25,7 +25,13 @@ from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.api.errors import (
AuthError,
Codes,
ConsentNotGivenError,
NotFoundError,
SynapseError,
)
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
@ -36,6 +42,7 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -82,28 +89,85 @@ class MessageHandler(object):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id, is_guest=False): def get_state_events(
self, user_id, room_id, types=None, filtered_types=None,
at_token=None, is_guest=False,
):
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has joined to the room then return the current state. If the user has
left the room return the state events from when they left. left the room return the state events from when they left. If an explicit
'at' parameter is passed, return the state events as of that event, if
visible.
Args: Args:
user_id(str): The user requesting state events. user_id(str): The user requesting state events.
room_id(str): The room ID to get all state events from. room_id(str): The room ID to get all state events from.
types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type.
May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
at_token(StreamToken|None): the stream token of the at which we are requesting
the stats. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current
state based on the current_state_events table.
is_guest(bool): whether this user is a guest
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" Raises:
membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( NotFoundError (404) if the at token does not yield an event
room_id, user_id
)
if membership == Membership.JOIN: AuthError (403) if the user doesn't have permission to view
room_state = yield self.state.get_current_state(room_id) members of this room.
elif membership == Membership.LEAVE: """
room_state = yield self.store.get_state_for_events( if at_token:
[membership_event_id], None # FIXME this claims to get the state at a stream position, but
# get_recent_events_for_room operates by topo ordering. This therefore
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
room_id, end_token=at_token.room_key, limit=1,
) )
room_state = room_state[membership_event_id]
if not last_events:
raise NotFoundError("Can't find event for token %s" % (at_token, ))
visible_events = yield filter_events_for_client(
self.store, user_id, last_events,
)
event = last_events[0]
if visible_events:
room_state = yield self.store.get_state_for_events(
[event.event_id], types, filtered_types=filtered_types,
)
room_state = room_state[event.event_id]
else:
raise AuthError(
403,
"User %s not allowed to view events in room %s at token %s" % (
user_id, room_id, at_token,
)
)
else:
membership, membership_event_id = (
yield self.auth.check_in_room_or_world_readable(
room_id, user_id,
)
)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
room_id, types, filtered_types=filtered_types,
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
[membership_event_id], types, filtered_types=filtered_types,
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec() now = self.clock.time_msec()
defer.returnValue( defer.returnValue(

View File

@ -34,7 +34,7 @@ from synapse.http.servlet import (
parse_string, parse_string,
) )
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -384,15 +384,39 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
events = yield self.message_handler.get_state_events( handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
# for consistency with /messages etc.
# useful for getting the membership in retrospect as of a given /sync
# response.
at_token_string = parse_string(request, "at")
if at_token_string is None:
at_token = None
else:
at_token = StreamToken.from_string(at_token_string)
# let you filter down on particular memberships.
# XXX: this may not be the best shape for this API - we could pass in a filter
# instead, except filters aren't currently aware of memberships.
# See https://github.com/matrix-org/matrix-doc/issues/1337 for more details.
membership = parse_string(request, "membership")
not_membership = parse_string(request, "not_membership")
events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
at_token=at_token,
types=[(EventTypes.Member, None)],
) )
chunk = [] chunk = []
for event in events: for event in events:
if event["type"] != EventTypes.Member: if (
(membership and event['content'].get("membership") != membership) or
(not_membership and event['content'].get("membership") == not_membership)
):
continue continue
chunk.append(event) chunk.append(event)
@ -401,6 +425,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
})) }))
# deprecated in favour of /members?membership=join?
# except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(ClientV1RestServlet): class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")

View File

@ -1911,7 +1911,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
max_depth = max(row[0] for row in rows) max_depth = max(row[0] for row in rows)
if max_depth <= token.topological: if max_depth <= token.topological:
# We need to ensure we don't delete all the events from the datanase # We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not # otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties) # having any backwards extremeties)
raise SynapseError( raise SynapseError(

View File

@ -116,6 +116,69 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
_get_current_state_ids_txn, _get_current_state_ids_txn,
) )
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Args:
room_id (str)
types (list[(Str, (Str|None))]): List of (type, state_key) tuples
which are used to filter the state fetched. `state_key` may be
None, which matches any `state_key`
filtered_types (list[Str]|None): List of types to apply the above filter to.
Returns:
deferred: dict of (type, state_key) -> event
"""
include_other_types = False if filtered_types is None else True
def _get_filtered_current_state_ids_txn(txn):
results = {}
sql = """SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ? %s"""
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type seperately.
if types:
clause_to_args = [
(
"AND type = ? AND state_key = ?",
(etype, state_key)
) if state_key is not None else (
"AND type = ?",
(etype,)
)
for etype, state_key in types
]
if include_other_types:
unique_types = set(filtered_types)
clause_to_args.append(
(
"AND type <> ? " * len(unique_types),
list(unique_types)
)
)
else:
# If types is None we fetch all the state, and so just use an
# empty where clause with no extra args.
clause_to_args = [("", [])]
for where_clause, where_args in clause_to_args:
args = [room_id]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (intern_string(typ), intern_string(state_key))
results[key] = event_id
return results
return self.runInteraction(
"get_filtered_current_state_ids",
_get_filtered_current_state_ids_txn,
)
@cached(max_entries=10000, iterable=True) @cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group): def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between """Given a state group try to return a previous group and a delta between
@ -389,8 +452,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
If None, `types` filtering is applied to all events. If None, `types` filtering is applied to all events.
Returns: Returns:
deferred: A list of dicts corresponding to the event_ids given. deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
The dicts are mappings from (type, state_key) -> state_events
""" """
event_to_groups = yield self._get_state_group_for_events( event_to_groups = yield self._get_state_group_for_events(
event_ids, event_ids,

View File

@ -124,7 +124,7 @@ class AuthTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_exceeded(self): def test_mau_limits_exceeded_large(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
@ -141,6 +141,42 @@ class AuthTestCase(unittest.TestCase):
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )
@defer.inlineCallbacks
def test_mau_limits_parity(self):
self.hs.config.limit_usage_by_mau = True
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_not_exceeded(self): def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True

View File

@ -98,7 +98,7 @@ class RegistrationTestCase(unittest.TestCase):
def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock( self.store.count_monthly_users = Mock(
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.hs.config.max_mau_value - 1)
) )
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user("@user:server", 'c', "User") yield self.handler.get_or_create_user("@user:server", 'c', "User")
@ -112,6 +112,12 @@ class RegistrationTestCase(unittest.TestCase):
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_register_mau_blocked(self): def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
@ -121,6 +127,12 @@ class RegistrationTestCase(unittest.TestCase):
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_register_saml2_mau_blocked(self): def test_register_saml2_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
@ -129,3 +141,9 @@ class RegistrationTestCase(unittest.TestCase):
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part")

View File

@ -148,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
) )
# check we can use filter_types to grab a specific room member # check we can use filtered_types to grab a specific room member
# without filtering out the other event types # without filtering out the other event types
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, e5.event_id,