Add + as an allowed character for Matrix IDs (MSC4009) (#15911)
This commit is contained in:
parent
92014fbf72
commit
a4243183f0
|
@ -0,0 +1 @@
|
||||||
|
Allow `+` in Matrix IDs, per [MSC4009](https://github.com/matrix-org/matrix-spec-proposals/pull/4009).
|
|
@ -382,9 +382,6 @@ class ExperimentalConfig(Config):
|
||||||
# Check that none of the other config options conflict with MSC3861 when enabled
|
# Check that none of the other config options conflict with MSC3861 when enabled
|
||||||
self.msc3861.check_config_conflicts(self.root)
|
self.msc3861.check_config_conflicts(self.root)
|
||||||
|
|
||||||
# MSC4009: E.164 Matrix IDs
|
|
||||||
self.msc4009_e164_mxids = experimental.get("msc4009_e164_mxids", False)
|
|
||||||
|
|
||||||
# MSC4010: Do not allow setting m.push_rules account data.
|
# MSC4010: Do not allow setting m.push_rules account data.
|
||||||
self.msc4010_push_rules_account_data = experimental.get(
|
self.msc4010_push_rules_account_data = experimental.get(
|
||||||
"msc4010_push_rules_account_data", False
|
"msc4010_push_rules_account_data", False
|
||||||
|
|
|
@ -143,15 +143,10 @@ class RegistrationHandler:
|
||||||
assigned_user_id: Optional[str] = None,
|
assigned_user_id: Optional[str] = None,
|
||||||
inhibit_user_in_use_error: bool = False,
|
inhibit_user_in_use_error: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if types.contains_invalid_mxid_characters(
|
if types.contains_invalid_mxid_characters(localpart):
|
||||||
localpart, self.hs.config.experimental.msc4009_e164_mxids
|
|
||||||
):
|
|
||||||
extra_chars = (
|
|
||||||
"=_-./+" if self.hs.config.experimental.msc4009_e164_mxids else "=_-./"
|
|
||||||
)
|
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
f"User ID can only contain characters a-z, 0-9, or '{extra_chars}'",
|
"User ID can only contain characters a-z, 0-9, or '=_-./+'",
|
||||||
Codes.INVALID_USERNAME,
|
Codes.INVALID_USERNAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -27,9 +27,9 @@ from synapse.http.servlet import parse_string
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
MXID_LOCALPART_ALLOWED_CHARACTERS,
|
||||||
UserID,
|
UserID,
|
||||||
map_username_to_mxid_localpart,
|
map_username_to_mxid_localpart,
|
||||||
mxid_localpart_allowed_characters,
|
|
||||||
)
|
)
|
||||||
from synapse.util.iterutils import chunk_seq
|
from synapse.util.iterutils import chunk_seq
|
||||||
|
|
||||||
|
@ -371,7 +371,7 @@ class SamlHandler:
|
||||||
|
|
||||||
|
|
||||||
DOT_REPLACE_PATTERN = re.compile(
|
DOT_REPLACE_PATTERN = re.compile(
|
||||||
"[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)
|
"[^%s]" % (re.escape("".join(MXID_LOCALPART_ALLOWED_CHARACTERS)),)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -225,8 +225,6 @@ class SsoHandler:
|
||||||
|
|
||||||
self._consent_at_registration = hs.config.consent.user_consent_at_registration
|
self._consent_at_registration = hs.config.consent.user_consent_at_registration
|
||||||
|
|
||||||
self._e164_mxids = hs.config.experimental.msc4009_e164_mxids
|
|
||||||
|
|
||||||
def register_identity_provider(self, p: SsoIdentityProvider) -> None:
|
def register_identity_provider(self, p: SsoIdentityProvider) -> None:
|
||||||
p_id = p.idp_id
|
p_id = p.idp_id
|
||||||
assert p_id not in self._identity_providers
|
assert p_id not in self._identity_providers
|
||||||
|
@ -713,7 +711,7 @@ class SsoHandler:
|
||||||
# Since the localpart is provided via a potentially untrusted module,
|
# Since the localpart is provided via a potentially untrusted module,
|
||||||
# ensure the MXID is valid before registering.
|
# ensure the MXID is valid before registering.
|
||||||
if not attributes.localpart or contains_invalid_mxid_characters(
|
if not attributes.localpart or contains_invalid_mxid_characters(
|
||||||
attributes.localpart, self._e164_mxids
|
attributes.localpart
|
||||||
):
|
):
|
||||||
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
|
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
|
||||||
|
|
||||||
|
@ -946,7 +944,7 @@ class SsoHandler:
|
||||||
localpart,
|
localpart,
|
||||||
)
|
)
|
||||||
|
|
||||||
if contains_invalid_mxid_characters(localpart, self._e164_mxids):
|
if contains_invalid_mxid_characters(localpart):
|
||||||
raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
|
raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
|
||||||
user_id = UserID(localpart, self._server_name).to_string()
|
user_id = UserID(localpart, self._server_name).to_string()
|
||||||
user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
|
user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
|
||||||
|
|
|
@ -348,22 +348,15 @@ class EventID(DomainSpecificString):
|
||||||
SIGIL = "$"
|
SIGIL = "$"
|
||||||
|
|
||||||
|
|
||||||
mxid_localpart_allowed_characters = set(
|
MXID_LOCALPART_ALLOWED_CHARACTERS = set(
|
||||||
"_-./=" + string.ascii_lowercase + string.digits
|
"_-./=+" + string.ascii_lowercase + string.digits
|
||||||
)
|
)
|
||||||
# MSC4007 adds the + to the allowed characters.
|
|
||||||
#
|
|
||||||
# TODO If this was accepted, update the SSO code to support this, see the callers
|
|
||||||
# of map_username_to_mxid_localpart.
|
|
||||||
extended_mxid_localpart_allowed_characters = mxid_localpart_allowed_characters | {"+"}
|
|
||||||
|
|
||||||
# Guest user IDs are purely numeric.
|
# Guest user IDs are purely numeric.
|
||||||
GUEST_USER_ID_PATTERN = re.compile(r"^\d+$")
|
GUEST_USER_ID_PATTERN = re.compile(r"^\d+$")
|
||||||
|
|
||||||
|
|
||||||
def contains_invalid_mxid_characters(
|
def contains_invalid_mxid_characters(localpart: str) -> bool:
|
||||||
localpart: str, use_extended_character_set: bool
|
|
||||||
) -> bool:
|
|
||||||
"""Check for characters not allowed in an mxid or groupid localpart
|
"""Check for characters not allowed in an mxid or groupid localpart
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -374,12 +367,7 @@ def contains_invalid_mxid_characters(
|
||||||
Returns:
|
Returns:
|
||||||
True if there are any naughty characters
|
True if there are any naughty characters
|
||||||
"""
|
"""
|
||||||
allowed_characters = (
|
return any(c not in MXID_LOCALPART_ALLOWED_CHARACTERS for c in localpart)
|
||||||
extended_mxid_localpart_allowed_characters
|
|
||||||
if use_extended_character_set
|
|
||||||
else mxid_localpart_allowed_characters
|
|
||||||
)
|
|
||||||
return any(c not in allowed_characters for c in localpart)
|
|
||||||
|
|
||||||
|
|
||||||
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
|
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
|
||||||
|
@ -396,7 +384,7 @@ UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
|
||||||
# bytes rather than strings
|
# bytes rather than strings
|
||||||
#
|
#
|
||||||
NON_MXID_CHARACTER_PATTERN = re.compile(
|
NON_MXID_CHARACTER_PATTERN = re.compile(
|
||||||
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode(
|
("[^%s]" % (re.escape("".join(MXID_LOCALPART_ALLOWED_CHARACTERS - {"="})),)).encode(
|
||||||
"ascii"
|
"ascii"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -587,17 +587,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertFalse(self.get_success(d))
|
self.assertFalse(self.get_success(d))
|
||||||
|
|
||||||
def test_invalid_user_id(self) -> None:
|
def test_invalid_user_id(self) -> None:
|
||||||
invalid_user_id = "+abcd"
|
invalid_user_id = "^abcd"
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
||||||
)
|
)
|
||||||
|
|
||||||
@override_config({"experimental_features": {"msc4009_e164_mxids": True}})
|
def test_special_chars(self) -> None:
|
||||||
def text_extended_user_ids(self) -> None:
|
"""Ensure that characters which are allowed in Matrix IDs work."""
|
||||||
"""+ should be allowed according to MSC4009."""
|
valid_user_id = "a1234_-./=+"
|
||||||
valid_user_id = "+1234"
|
|
||||||
user_id = self.get_success(self.handler.register_user(localpart=valid_user_id))
|
user_id = self.get_success(self.handler.register_user(localpart=valid_user_id))
|
||||||
self.assertEqual(user_id, valid_user_id)
|
self.assertEqual(user_id, f"@{valid_user_id}:test")
|
||||||
|
|
||||||
def test_invalid_user_id_length(self) -> None:
|
def test_invalid_user_id_length(self) -> None:
|
||||||
invalid_user_id = "x" * 256
|
invalid_user_id = "x" * 256
|
||||||
|
|
Loading…
Reference in New Issue