Validate input to POST /key/v2/query endpoint. (#16183)

To avoid 500 internal server errors with garbage input.
This commit is contained in:
Patrick Cloke 2023-08-25 14:10:31 -04:00 committed by GitHub
parent fcf7a5759e
commit 82699428e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 10 deletions

1
changelog.d/16183.misc Normal file
View File

@ -0,0 +1 @@
Improve error reporting of invalid data passed to `/_matrix/key/v2/query`.

View File

@ -16,6 +16,7 @@ import logging
import re import re
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from pydantic import Extra, StrictInt, StrictStr
from signedjson.sign import sign_json from signedjson.sign import sign_json
from twisted.web.server import Request from twisted.web.server import Request
@ -24,9 +25,10 @@ from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
parse_and_validate_json_object_from_request,
parse_integer, parse_integer,
parse_json_object_from_request,
) )
from synapse.rest.models import RequestBodyModel
from synapse.storage.keys import FetchKeyResultForRemote from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
@ -38,6 +40,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _KeyQueryCriteriaDataModel(RequestBodyModel):
class Config:
extra = Extra.allow
minimum_valid_until_ts: Optional[StrictInt]
class RemoteKey(RestServlet): class RemoteKey(RestServlet):
"""HTTP resource for retrieving the TLS certificate and NACL signature """HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported verification keys for a collection of servers. Checks that the reported
@ -96,6 +105,9 @@ class RemoteKey(RestServlet):
CATEGORY = "Federation requests" CATEGORY = "Federation requests"
class PostBody(RequestBodyModel):
server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]]
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.fetcher = ServerKeyFetcher(hs) self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
@ -137,24 +149,29 @@ class RemoteKey(RestServlet):
) )
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {} query = {
if minimum_valid_until_ts is not None: server: {
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts key_id: _KeyQueryCriteriaDataModel(
query = {server: {key_id: arguments}} minimum_valid_until_ts=minimum_valid_until_ts
)
}
}
else: else:
query = {server: {}} query = {server: {}}
return 200, await self.query_keys(query, query_remote_on_cache_miss=True) return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_and_validate_json_object_from_request(request, self.PostBody)
query = content["server_keys"] query = content.server_keys
return 200, await self.query_keys(query, query_remote_on_cache_miss=True) return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def query_keys( async def query_keys(
self, query: JsonDict, query_remote_on_cache_miss: bool = False self,
query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]],
query_remote_on_cache_miss: bool = False,
) -> JsonDict: ) -> JsonDict:
logger.info("Handling query for keys %r", query) logger.info("Handling query for keys %r", query)
@ -196,8 +213,10 @@ class RemoteKey(RestServlet):
else: else:
ts_added_ms = key_result.added_ts ts_added_ms = key_result.added_ts
ts_valid_until_ms = key_result.valid_until_ts ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {}) req_key = query.get(server_name, {}).get(
req_valid_until = req_key.get("minimum_valid_until_ts") key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None)
)
req_valid_until = req_key.minimum_valid_until_ts
if req_valid_until is not None: if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until: if ts_valid_until_ms < req_valid_until:
logger.debug( logger.debug(