Add a way to register the lower-case version of a conflicting MXID

This commit is contained in:
Eric Eastwood 2023-07-12 22:26:15 -05:00
parent 3bdb9b07fd
commit c1211e6dbe
3 changed files with 45 additions and 5 deletions

View File

@ -44,10 +44,18 @@ def request_registration(
shared_secret: str, shared_secret: str,
admin: bool = False, admin: bool = False,
user_type: Optional[str] = None, user_type: Optional[str] = None,
inhibit_user_in_use_error: bool = False,
_print: Callable[[str], None] = print, _print: Callable[[str], None] = print,
exit: Callable[[int], None] = sys.exit, exit: Callable[[int], None] = sys.exit,
) -> None: ) -> None:
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) qs_url_piece = ""
if inhibit_user_in_use_error:
qs_url_piece = "?inhibit_user_in_use_error=true"
url = "%s/_synapse/admin/v1/register%s" % (
server_location.rstrip("/"),
qs_url_piece,
)
# Get the nonce # Get the nonce
r = requests.get(url, verify=False) r = requests.get(url, verify=False)
@ -99,7 +107,8 @@ def request_registration(
pass pass
return exit(1) return exit(1)
_print("Success!") result = r.json()
_print("Success! -> %s" % result)
def register_new_user( def register_new_user(
@ -109,6 +118,7 @@ def register_new_user(
shared_secret: str, shared_secret: str,
admin: Optional[bool], admin: Optional[bool],
user_type: Optional[str], user_type: Optional[str],
inhibit_user_in_use_error: bool = False,
) -> None: ) -> None:
if not user: if not user:
try: try:
@ -148,7 +158,13 @@ def register_new_user(
admin = False admin = False
request_registration( request_registration(
user, password, server_location, shared_secret, bool(admin), user_type user,
password,
server_location,
shared_secret,
bool(admin),
user_type,
inhibit_user_in_use_error=inhibit_user_in_use_error,
) )
@ -179,6 +195,14 @@ def main() -> None:
default=None, default=None,
help="User type as specified in synapse.api.constants.UserTypes", help="User type as specified in synapse.api.constants.UserTypes",
) )
parser.add_argument(
"--inhibit_user_in_use_error",
default=False,
help="Whether to inhibit errors raised when registering a new account if the user ID already exists. "
"Useful when there is a collision with another MXID that has capital letters "
"but you want to register the same user with lower-case. "
"The registration will still fail if you try to register with the same MXID. Defaults to False",
)
admin_group = parser.add_mutually_exclusive_group() admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument( admin_group.add_argument(
"-a", "-a",
@ -264,7 +288,13 @@ def main() -> None:
admin = args.admin admin = args.admin
register_new_user( register_new_user(
args.user, args.password, server_url, secret, admin, args.user_type args.user,
args.password,
server_url,
secret,
admin,
args.user_type,
inhibit_user_in_use_error=args.inhibit_user_in_use_error,
) )

View File

@ -218,6 +218,7 @@ class RegistrationHandler:
user_agent_ips: Optional[List[Tuple[str, str]]] = None, user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None, auth_provider_id: Optional[str] = None,
approved: bool = False, approved: bool = False,
inhibit_user_in_use_error: bool = False,
) -> str: ) -> str:
"""Registers a new client on the server. """Registers a new client on the server.
@ -283,7 +284,11 @@ class RegistrationHandler:
await self.auth_blocking.check_auth_blocking(threepid=threepid) await self.auth_blocking.check_auth_blocking(threepid=threepid)
if localpart is not None: if localpart is not None:
await self.check_username(localpart, guest_access_token=guest_access_token) await self.check_username(
localpart,
guest_access_token=guest_access_token,
inhibit_user_in_use_error=inhibit_user_in_use_error,
)
was_guest = guest_access_token is not None was_guest = guest_access_token is not None

View File

@ -520,6 +520,10 @@ class UserRegisterServlet(RestServlet):
HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled" HTTPStatus.BAD_REQUEST, "Shared secret registration is not enabled"
) )
inhibit_user_in_use_error = parse_boolean(
request, "inhibit_user_in_use_error", False
)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
if "nonce" not in body: if "nonce" not in body:
@ -615,6 +619,7 @@ class UserRegisterServlet(RestServlet):
default_display_name=displayname, default_display_name=displayname,
by_admin=True, by_admin=True,
approved=True, approved=True,
inhibit_user_in_use_error=inhibit_user_in_use_error,
) )
result = await register._create_registration_details(user_id, body) result = await register._create_registration_details(user_id, body)