Merge pull request #3689 from matrix-org/neilj/fix_off_by_1+maus

Fix Mau off by one errors
This commit is contained in:
Neil Johnson 2018-08-15 16:19:41 +00:00 committed by GitHub
commit 81d727efa9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 4 deletions

1
changelog.d/3689.bugfix Normal file
View File

@ -0,0 +1 @@
Fix mau blocking calulation bug on login

View File

@ -520,7 +520,7 @@ class AuthHandler(BaseHandler):
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking(user_id)
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -734,7 +734,6 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
yield self.auth.check_auth_blocking()
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
user_id = None user_id = None
try: try:
@ -743,6 +742,7 @@ class AuthHandler(BaseHandler):
auth_api.validate_macaroon(macaroon, "login", True, user_id) auth_api.validate_macaroon(macaroon, "login", True, user_id)
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
yield self.auth.check_auth_blocking(user_id)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -124,7 +124,7 @@ class AuthTestCase(unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_exceeded(self): def test_mau_limits_exceeded_large(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
@ -141,6 +141,42 @@ class AuthTestCase(unittest.TestCase):
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )
@defer.inlineCallbacks
def test_mau_limits_parity(self):
self.hs.config.limit_usage_by_mau = True
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_not_exceeded(self): def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True

View File

@ -98,7 +98,7 @@ class RegistrationTestCase(unittest.TestCase):
def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.store.count_monthly_users = Mock( self.store.count_monthly_users = Mock(
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.hs.config.max_mau_value - 1)
) )
# Ensure does not throw exception # Ensure does not throw exception
yield self.handler.get_or_create_user("@user:server", 'c', "User") yield self.handler.get_or_create_user("@user:server", 'c', "User")
@ -112,6 +112,12 @@ class RegistrationTestCase(unittest.TestCase):
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_register_mau_blocked(self): def test_register_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
@ -121,6 +127,12 @@ class RegistrationTestCase(unittest.TestCase):
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_register_saml2_mau_blocked(self): def test_register_saml2_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
@ -129,3 +141,9 @@ class RegistrationTestCase(unittest.TestCase):
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part")