Use Pydantic to validate /devices endpoints (#14054)

This commit is contained in:
David Robertson 2022-10-07 13:54:07 +01:00 committed by GitHub
parent 1fa2e58772
commit 2295095c97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 47 deletions

View File

@ -0,0 +1 @@
Improve validation of request bodies for the [Device Management](https://spec.matrix.org/v1.4/client-server-api/#device-management) and [MSC2697 Device Dehyrdation](https://github.com/matrix-org/matrix-spec-proposals/pull/2697) client-server API endpoints.

View File

@ -14,18 +14,21 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from pydantic import Extra, StrictStr
from synapse.api import errors from synapse.api import errors
from synapse.api.errors import NotFoundError from synapse.api.errors import NotFoundError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, parse_and_validate_json_object_from_request,
parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
@ -80,27 +83,29 @@ class DeleteDevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData]
devices: List[StrictStr]
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
try: try:
body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.PostBody)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# DELETE # TODO: Can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = self.PostBody.parse_obj({})
else: else:
raise e raise e
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
request, request,
body, body.dict(exclude_unset=True),
"remove device(s) from your account", "remove device(s) from your account",
# Users might call this multiple times in a row while cleaning up # Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used. # devices, allow a single UI auth session to be re-used.
@ -108,7 +113,7 @@ class DeleteDevicesRestServlet(RestServlet):
) )
await self.device_handler.delete_devices( await self.device_handler.delete_devices(
requester.user.to_string(), body["devices"] requester.user.to_string(), body.devices
) )
return 200, {} return 200, {}
@ -147,6 +152,9 @@ class DeviceRestServlet(RestServlet):
return 200, device return 200, device
class DeleteBody(RequestBodyModel):
auth: Optional[AuthenticationData]
@interactive_auth_handler @interactive_auth_handler
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, device_id: str self, request: SynapseRequest, device_id: str
@ -154,20 +162,21 @@ class DeviceRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
try: try:
body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.DeleteBody)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# TODO: can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = self.DeleteBody.parse_obj({})
else: else:
raise raise
await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
request, request,
body, body.dict(exclude_unset=True),
"remove a device from your account", "remove a device from your account",
# Users might call this multiple times in a row while cleaning up # Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used. # devices, allow a single UI auth session to be re-used.
@ -179,18 +188,33 @@ class DeviceRestServlet(RestServlet):
) )
return 200, {} return 200, {}
class PutBody(RequestBodyModel):
display_name: Optional[StrictStr]
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, device_id: str self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.PutBody)
await self.device_handler.update_device( await self.device_handler.update_device(
requester.user.to_string(), device_id, body requester.user.to_string(), device_id, body.dict()
) )
return 200, {} return 200, {}
class DehydratedDeviceDataModel(RequestBodyModel):
"""JSON blob describing a dehydrated device to be stored.
Expects other freeform fields. Use .dict() to access them.
"""
class Config:
extra = Extra.allow
algorithm: StrictStr
class DehydratedDeviceServlet(RestServlet): class DehydratedDeviceServlet(RestServlet):
"""Retrieve or store a dehydrated device. """Retrieve or store a dehydrated device.
@ -246,27 +270,19 @@ class DehydratedDeviceServlet(RestServlet):
else: else:
raise errors.NotFoundError("No dehydrated device available") raise errors.NotFoundError("No dehydrated device available")
async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: class PutBody(RequestBodyModel):
submission = parse_json_object_from_request(request) device_id: StrictStr
requester = await self.auth.get_user_by_req(request) device_data: DehydratedDeviceDataModel
initial_device_display_name: Optional[StrictStr]
if "device_data" not in submission: async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
raise errors.SynapseError( submission = parse_and_validate_json_object_from_request(request, self.PutBody)
400, requester = await self.auth.get_user_by_req(request)
"device_data missing",
errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_data"], dict):
raise errors.SynapseError(
400,
"device_data must be an object",
errcode=errors.Codes.INVALID_PARAM,
)
device_id = await self.device_handler.store_dehydrated_device( device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(), requester.user.to_string(),
submission["device_data"], submission.device_data,
submission.get("initial_device_display_name", None), submission.initial_device_display_name,
) )
return 200, {"device_id": device_id} return 200, {"device_id": device_id}
@ -300,28 +316,18 @@ class ClaimDehydratedDeviceServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
class PostBody(RequestBodyModel):
device_id: StrictStr
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
submission = parse_json_object_from_request(request) submission = parse_and_validate_json_object_from_request(request, self.PostBody)
if "device_id" not in submission:
raise errors.SynapseError(
400,
"device_id missing",
errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_id"], str):
raise errors.SynapseError(
400,
"device_id must be a string",
errcode=errors.Codes.INVALID_PARAM,
)
result = await self.device_handler.rehydrate_device( result = await self.device_handler.rehydrate_device(
requester.user.to_string(), requester.user.to_string(),
self.auth.get_access_token_from_request(request), self.auth.get_access_token_from_request(request),
submission["device_id"], submission.device_id,
) )
return 200, result return 200, result