A third batch of Pydantic validation for rest/client/account.py (#13736)
This commit is contained in:
parent
918c74bfb5
commit
742f9f9d78
|
@ -0,0 +1 @@
|
|||
Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind).
|
|
@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
|||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import StrictBool, StrictStr, constr
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
|
@ -43,6 +44,7 @@ from synapse.metrics import threepid_send_requests
|
|||
from synapse.push.mailer import Mailer
|
||||
from synapse.rest.client.models import (
|
||||
AuthenticationData,
|
||||
ClientSecretStr,
|
||||
EmailRequestTokenBody,
|
||||
MsisdnRequestTokenBody,
|
||||
)
|
||||
|
@ -627,6 +629,11 @@ class ThreepidAddRestServlet(RestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
auth: Optional[AuthenticationData] = None
|
||||
client_secret: ClientSecretStr
|
||||
sid: StrictStr
|
||||
|
||||
@interactive_auth_handler
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
if not self.hs.config.registration.enable_3pid_changes:
|
||||
|
@ -636,22 +643,17 @@ class ThreepidAddRestServlet(RestServlet):
|
|||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_dict(body, ["client_secret", "sid"])
|
||||
sid = body["sid"]
|
||||
client_secret = body["client_secret"]
|
||||
assert_valid_client_secret(client_secret)
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester,
|
||||
request,
|
||||
body,
|
||||
body.dict(exclude_unset=True),
|
||||
"add a third-party identifier to your account",
|
||||
)
|
||||
|
||||
validation_session = await self.identity_handler.validate_threepid_session(
|
||||
client_secret, sid
|
||||
body.client_secret, body.sid
|
||||
)
|
||||
if validation_session:
|
||||
await self.auth_handler.add_threepid(
|
||||
|
@ -676,23 +678,20 @@ class ThreepidBindRestServlet(RestServlet):
|
|||
self.identity_handler = hs.get_identity_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
body = parse_json_object_from_request(request)
|
||||
class PostBody(RequestBodyModel):
|
||||
client_secret: ClientSecretStr
|
||||
id_access_token: StrictStr
|
||||
id_server: StrictStr
|
||||
sid: StrictStr
|
||||
|
||||
assert_params_in_dict(
|
||||
body, ["id_server", "sid", "id_access_token", "client_secret"]
|
||||
)
|
||||
id_server = body["id_server"]
|
||||
sid = body["sid"]
|
||||
id_access_token = body["id_access_token"]
|
||||
client_secret = body["client_secret"]
|
||||
assert_valid_client_secret(client_secret)
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
await self.identity_handler.bind_threepid(
|
||||
client_secret, sid, user_id, id_server, id_access_token
|
||||
body.client_secret, body.sid, user_id, body.id_server, body.id_access_token
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
@ -708,23 +707,27 @@ class ThreepidUnbindRestServlet(RestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
self.datastore = self.hs.get_datastores().main
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
address: StrictStr
|
||||
id_server: Optional[StrictStr] = None
|
||||
medium: Literal["email", "msisdn"]
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
"""Unbind the given 3pid from a specific identity server, or identity servers that are
|
||||
known to have this 3pid bound
|
||||
"""
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ["medium", "address"])
|
||||
|
||||
medium = body.get("medium")
|
||||
address = body.get("address")
|
||||
id_server = body.get("id_server")
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
|
||||
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
|
||||
# unbind from all identity servers this threepid has been added to in the past
|
||||
result = await self.identity_handler.try_unbind_threepid(
|
||||
requester.user.to_string(),
|
||||
{"address": address, "medium": medium, "id_server": id_server},
|
||||
{
|
||||
"address": body.address,
|
||||
"medium": body.medium,
|
||||
"id_server": body.id_server,
|
||||
},
|
||||
)
|
||||
return 200, {"id_server_unbind_result": "success" if result else "no-support"}
|
||||
|
||||
|
@ -738,21 +741,25 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
address: StrictStr
|
||||
id_server: Optional[StrictStr] = None
|
||||
medium: Literal["email", "msisdn"]
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
if not self.hs.config.registration.enable_3pid_changes:
|
||||
raise SynapseError(
|
||||
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ["medium", "address"])
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
try:
|
||||
ret = await self.auth_handler.delete_threepid(
|
||||
user_id, body["medium"], body["address"], body.get("id_server")
|
||||
user_id, body.medium, body.address, body.id_server
|
||||
)
|
||||
except Exception:
|
||||
# NB. This endpoint should succeed if there is nothing to
|
||||
|
|
|
@ -36,18 +36,20 @@ class AuthenticationData(RequestBodyModel):
|
|||
type: Optional[StrictStr] = None
|
||||
|
||||
|
||||
class ThreePidRequestTokenBody(RequestBodyModel):
|
||||
if TYPE_CHECKING:
|
||||
client_secret: StrictStr
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
ClientSecretStr = StrictStr
|
||||
else:
|
||||
# See also assert_valid_client_secret()
|
||||
client_secret: constr(
|
||||
ClientSecretStr = constr(
|
||||
regex="[0-9a-zA-Z.=_-]", # noqa: F722
|
||||
min_length=0,
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
strict=True,
|
||||
)
|
||||
|
||||
|
||||
class ThreepidRequestTokenBody(RequestBodyModel):
|
||||
client_secret: ClientSecretStr
|
||||
id_server: Optional[StrictStr]
|
||||
id_access_token: Optional[StrictStr]
|
||||
next_link: Optional[StrictStr]
|
||||
|
@ -62,7 +64,7 @@ class ThreePidRequestTokenBody(RequestBodyModel):
|
|||
return token
|
||||
|
||||
|
||||
class EmailRequestTokenBody(ThreePidRequestTokenBody):
|
||||
class EmailRequestTokenBody(ThreepidRequestTokenBody):
|
||||
email: StrictStr
|
||||
|
||||
# Canonicalise the email address. The addresses are all stored canonicalised
|
||||
|
@ -80,6 +82,6 @@ else:
|
|||
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
|
||||
|
||||
|
||||
class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
|
||||
class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
|
||||
country: ISO3116_1_Alpha_2
|
||||
phone_number: StrictStr
|
||||
|
|
|
@ -11,14 +11,37 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
import unittest as stdlib_unittest
|
||||
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.rest.client.models import EmailRequestTokenBody
|
||||
|
||||
|
||||
class EmailRequestTokenBodyTestCase(unittest.TestCase):
|
||||
class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
|
||||
class Model(BaseModel):
|
||||
medium: Literal["email", "msisdn"]
|
||||
|
||||
def test_accepts_valid_medium_string(self) -> None:
|
||||
"""Sanity check that Pydantic behaves sensibly with an enum-of-str
|
||||
|
||||
This is arguably more of a test of a class that inherits from str and Enum
|
||||
simultaneously.
|
||||
"""
|
||||
model = self.Model.parse_obj({"medium": "email"})
|
||||
self.assertEqual(model.medium, "email")
|
||||
|
||||
def test_rejects_invalid_medium_value(self) -> None:
|
||||
with self.assertRaises(ValidationError):
|
||||
self.Model.parse_obj({"medium": "interpretive_dance"})
|
||||
|
||||
def test_rejects_invalid_medium_type(self) -> None:
|
||||
with self.assertRaises(ValidationError):
|
||||
self.Model.parse_obj({"medium": 123})
|
||||
|
||||
|
||||
class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
|
||||
base_request = {
|
||||
"client_secret": "hunter2",
|
||||
"email": "alice@wonderland.com",
|
||||
|
|
Loading…
Reference in New Issue