Correctly handle AS registerations and add test
This commit is contained in:
parent
35be260090
commit
4c33796b20
|
@ -709,6 +709,7 @@ class AuthHandler(BaseHandler):
|
||||||
device_id: Optional[str],
|
device_id: Optional[str],
|
||||||
valid_until_ms: Optional[int],
|
valid_until_ms: Optional[int],
|
||||||
puppets_user_id: Optional[str] = None,
|
puppets_user_id: Optional[str] = None,
|
||||||
|
is_appservice_ghost: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Creates a new access token for the user with the given user ID.
|
Creates a new access token for the user with the given user ID.
|
||||||
|
@ -725,6 +726,7 @@ class AuthHandler(BaseHandler):
|
||||||
we should always have a device ID)
|
we should always have a device ID)
|
||||||
valid_until_ms: when the token is valid until. None for
|
valid_until_ms: when the token is valid until. None for
|
||||||
no expiry.
|
no expiry.
|
||||||
|
is_appservice_ghost: Whether the user is an application ghost user
|
||||||
Returns:
|
Returns:
|
||||||
The access token for the user's session.
|
The access token for the user's session.
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -745,6 +747,10 @@ class AuthHandler(BaseHandler):
|
||||||
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
|
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not is_appservice_ghost
|
||||||
|
or self.hs.config.appservice.track_appservice_user_ips
|
||||||
|
):
|
||||||
await self.auth.check_auth_blocking(user_id)
|
await self.auth.check_auth_blocking(user_id)
|
||||||
|
|
||||||
access_token = self.macaroon_gen.generate_access_token(user_id)
|
access_token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
|
|
|
@ -630,6 +630,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
device_id: Optional[str],
|
device_id: Optional[str],
|
||||||
initial_display_name: Optional[str],
|
initial_display_name: Optional[str],
|
||||||
is_guest: bool = False,
|
is_guest: bool = False,
|
||||||
|
is_appservice_ghost: bool = False,
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""Register a device for a user and generate an access token.
|
"""Register a device for a user and generate an access token.
|
||||||
|
|
||||||
|
@ -651,6 +652,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
initial_display_name=initial_display_name,
|
initial_display_name=initial_display_name,
|
||||||
is_guest=is_guest,
|
is_guest=is_guest,
|
||||||
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
)
|
)
|
||||||
return r["device_id"], r["access_token"]
|
return r["device_id"], r["access_token"]
|
||||||
|
|
||||||
|
@ -672,7 +674,10 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
access_token = await self._auth_handler.get_access_token_for_user_id(
|
access_token = await self._auth_handler.get_access_token_for_user_id(
|
||||||
user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
|
user_id,
|
||||||
|
device_id=registered_device_id,
|
||||||
|
valid_until_ms=valid_until_ms,
|
||||||
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (registered_device_id, access_token)
|
return (registered_device_id, access_token)
|
||||||
|
|
|
@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
|
async def _serialize_payload(
|
||||||
|
user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
device_id (str|None): Device ID to use, if None a new one is
|
device_id (str|None): Device ID to use, if None a new one is
|
||||||
|
@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"initial_display_name": initial_display_name,
|
"initial_display_name": initial_display_name,
|
||||||
"is_guest": is_guest,
|
"is_guest": is_guest,
|
||||||
|
"is_appservice_ghost": is_appservice_ghost,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request(self, request, user_id):
|
||||||
|
@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
||||||
device_id = content["device_id"]
|
device_id = content["device_id"]
|
||||||
initial_display_name = content["initial_display_name"]
|
initial_display_name = content["initial_display_name"]
|
||||||
is_guest = content["is_guest"]
|
is_guest = content["is_guest"]
|
||||||
|
is_appservice_ghost = content["is_appservice_ghost"]
|
||||||
|
|
||||||
device_id, access_token = await self.registration_handler.register_device(
|
device_id, access_token = await self.registration_handler.register_device(
|
||||||
user_id, device_id, initial_display_name, is_guest
|
user_id,
|
||||||
|
device_id,
|
||||||
|
initial_display_name,
|
||||||
|
is_guest,
|
||||||
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {"device_id": device_id, "access_token": access_token}
|
return 200, {"device_id": device_id, "access_token": access_token}
|
||||||
|
|
|
@ -655,9 +655,13 @@ class RegisterRestServlet(RestServlet):
|
||||||
user_id = await self.registration_handler.appservice_register(
|
user_id = await self.registration_handler.appservice_register(
|
||||||
username, as_token
|
username, as_token
|
||||||
)
|
)
|
||||||
return await self._create_registration_details(user_id, body)
|
return await self._create_registration_details(
|
||||||
|
user_id, body, is_appservice_ghost=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def _create_registration_details(self, user_id, params):
|
async def _create_registration_details(
|
||||||
|
self, user_id, params, is_appservice_ghost=False
|
||||||
|
):
|
||||||
"""Complete registration of newly-registered user
|
"""Complete registration of newly-registered user
|
||||||
|
|
||||||
Allocates device_id if one was not given; also creates access_token.
|
Allocates device_id if one was not given; also creates access_token.
|
||||||
|
@ -674,7 +678,11 @@ class RegisterRestServlet(RestServlet):
|
||||||
device_id = params.get("device_id")
|
device_id = params.get("device_id")
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
device_id, access_token = await self.registration_handler.register_device(
|
device_id, access_token = await self.registration_handler.register_device(
|
||||||
user_id, device_id, initial_display_name, is_guest=False
|
user_id,
|
||||||
|
device_id,
|
||||||
|
initial_display_name,
|
||||||
|
is_guest=False,
|
||||||
|
is_appservice_ghost=is_appservice_ghost,
|
||||||
)
|
)
|
||||||
|
|
||||||
result.update({"access_token": access_token, "device_id": device_id})
|
result.update({"access_token": access_token, "device_id": device_id})
|
||||||
|
|
|
@ -19,6 +19,7 @@ import json
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.rest.client.v2_alpha import register, sync
|
from synapse.rest.client.v2_alpha import register, sync
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -75,6 +76,44 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(e.code, 403)
|
self.assertEqual(e.code, 403)
|
||||||
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
|
|
||||||
|
def test_as_ignores_mau(self):
|
||||||
|
"""Test that application services can still create users when the MAU
|
||||||
|
limit has been reached.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create and sync so that the MAU counts get updated
|
||||||
|
token1 = self.create_user("kermit1")
|
||||||
|
self.do_sync_for_user(token1)
|
||||||
|
token2 = self.create_user("kermit2")
|
||||||
|
self.do_sync_for_user(token2)
|
||||||
|
|
||||||
|
# check we're testing what we think we are: there should be two active users
|
||||||
|
self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
|
||||||
|
|
||||||
|
# We've created and activated two users, we shouldn't be able to
|
||||||
|
# register new users
|
||||||
|
with self.assertRaises(SynapseError) as cm:
|
||||||
|
self.create_user("kermit3")
|
||||||
|
|
||||||
|
e = cm.exception
|
||||||
|
self.assertEqual(e.code, 403)
|
||||||
|
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
|
|
||||||
|
# Cheekily add an application service that we use to register a new user
|
||||||
|
# with.
|
||||||
|
as_token = "foobartoken"
|
||||||
|
self.store.services_cache.append(
|
||||||
|
ApplicationService(
|
||||||
|
token=as_token,
|
||||||
|
hostname=self.hs.hostname,
|
||||||
|
id="SomeASID",
|
||||||
|
sender="@as_sender:test",
|
||||||
|
namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.create_user("as_kermit4", token=as_token)
|
||||||
|
|
||||||
def test_allowed_after_a_month_mau(self):
|
def test_allowed_after_a_month_mau(self):
|
||||||
# Create and sync so that the MAU counts get updated
|
# Create and sync so that the MAU counts get updated
|
||||||
token1 = self.create_user("kermit1")
|
token1 = self.create_user("kermit1")
|
||||||
|
@ -192,7 +231,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
||||||
self.reactor.advance(100)
|
self.reactor.advance(100)
|
||||||
self.assertEqual(2, self.successResultOf(count))
|
self.assertEqual(2, self.successResultOf(count))
|
||||||
|
|
||||||
def create_user(self, localpart):
|
def create_user(self, localpart, token=None):
|
||||||
request_data = json.dumps(
|
request_data = json.dumps(
|
||||||
{
|
{
|
||||||
"username": localpart,
|
"username": localpart,
|
||||||
|
@ -201,7 +240,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
request, channel = self.make_request("POST", "/register", request_data)
|
request, channel = self.make_request(
|
||||||
|
"POST", "/register", request_data, access_token=token
|
||||||
|
)
|
||||||
|
|
||||||
if channel.code != 200:
|
if channel.code != 200:
|
||||||
raise HttpResponseException(
|
raise HttpResponseException(
|
||||||
|
|
Loading…
Reference in New Issue