Be stricter about JSON that is accepted by Synapse (#8106)

This commit is contained in:
Patrick Cloke 2020-08-19 07:26:03 -04:00 committed by GitHub
parent d89692ea84
commit eebf52be06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 85 additions and 62 deletions

1
changelog.d/8106.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where invalid JSON would be accepted by Synapse.

View File

@ -21,10 +21,10 @@ import typing
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from canonicaljson import json
from twisted.web import http from twisted.web import http
from synapse.util import json_decoder
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from synapse.types import JsonDict from synapse.types import JsonDict
@ -593,7 +593,7 @@ class HttpResponseException(CodeMessageException):
# try to parse the body as json, to get better errcode/msg, but # try to parse the body as json, to get better errcode/msg, but
# default to M_UNKNOWN with the HTTP status as the error text # default to M_UNKNOWN with the HTTP status as the error text
try: try:
j = json.loads(self.response.decode("utf-8")) j = json_decoder.decode(self.response.decode("utf-8"))
except ValueError: except ValueError:
j = {} j = {}

View File

@ -28,7 +28,6 @@ from typing import (
Union, Union,
) )
from canonicaljson import json
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
from twisted.internet import defer from twisted.internet import defer
@ -63,7 +62,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet, ReplicationGetQueryRestServlet,
) )
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, unwrapFirstError from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -551,7 +550,7 @@ class FederationServer(FederationBase):
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
for key_id, json_str in keys.items(): for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = { json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_str) key_id: json_decoder.decode(json_str)
} }
logger.info( logger.info(

View File

@ -15,8 +15,6 @@
import logging import logging
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List, Tuple
from canonicaljson import json
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
@ -28,6 +26,7 @@ from synapse.logging.opentracing import (
tags, tags,
whitelisted_homeserver, whitelisted_homeserver,
) )
from synapse.util import json_decoder
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
if TYPE_CHECKING: if TYPE_CHECKING:
@ -71,7 +70,7 @@ class TransactionManager(object):
for edu in pending_edus: for edu in pending_edus:
context = edu.get_context() context = edu.get_context()
if context: if context:
span_contexts.append(extract_text_map(json.loads(context))) span_contexts.append(extract_text_map(json_decoder.decode(context)))
if keep_destination: if keep_destination:
edu.strip_context() edu.strip_context()

View File

@ -19,7 +19,7 @@ import logging
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import attr import attr
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json
from signedjson.key import VerifyKey, decode_verify_key_bytes from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
@ -35,7 +35,7 @@ from synapse.types import (
get_domain_from_id, get_domain_from_id,
get_verify_key_from_cross_signing_key, get_verify_key_from_cross_signing_key,
) )
from synapse.util import unwrapFirstError from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -404,7 +404,7 @@ class E2eKeysHandler(object):
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items(): for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = { json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes) key_id: json_decoder.decode(json_bytes)
} }
@trace @trace
@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
def _one_time_keys_match(old_key_json, new_key): def _one_time_keys_match(old_key_json, new_key):
old_key = json.loads(old_key_json) old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly # if either is a string rather than an object, they must match exactly
if not isinstance(old_key, dict) or not isinstance(new_key, dict): if not isinstance(old_key, dict) or not isinstance(new_key, dict):

View File

@ -21,8 +21,6 @@ import logging
import urllib.parse import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from canonicaljson import json
from twisted.internet.error import TimeoutError from twisted.internet.error import TimeoutError
from synapse.api.errors import ( from synapse.api.errors import (
@ -34,6 +32,7 @@ from synapse.api.errors import (
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, Requester from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64 from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.stringutils import assert_valid_client_secret, random_string
@ -177,7 +176,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e: except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT? data = json_decoder.decode(e.msg) # XXX WAT?
return data return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)

View File

@ -17,7 +17,7 @@
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json
from twisted.internet.interfaces import IDelayedCall from twisted.internet.interfaces import IDelayedCall
@ -55,6 +55,7 @@ from synapse.types import (
UserID, UserID,
create_requester, create_requester,
) )
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
@ -864,7 +865,7 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db # Ensure that we can round trip before trying to persist in db
try: try:
dump = frozendict_json_encoder.encode(event.content) dump = frozendict_json_encoder.encode(event.content)
json.loads(dump) json_decoder.decode(dump)
except Exception: except Exception:
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode from urllib.parse import urlencode
@ -39,6 +38,7 @@ from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -367,7 +367,7 @@ class OidcHandler:
# and check for an error field. If not, we respond with a generic # and check for an error field. If not, we respond with a generic
# error message. # error message.
try: try:
resp = json.loads(resp_body.decode("utf-8")) resp = json_decoder.decode(resp_body.decode("utf-8"))
error = resp["error"] error = resp["error"]
description = resp.get("error_description", error) description = resp.get("error_description", error)
except (ValueError, KeyError): except (ValueError, KeyError):
@ -384,7 +384,7 @@ class OidcHandler:
# Since it is a not a 5xx code, body should be a valid JSON. It will # Since it is a not a 5xx code, body should be a valid JSON. It will
# raise if not. # raise if not.
resp = json.loads(resp_body.decode("utf-8")) resp = json_decoder.decode(resp_body.decode("utf-8"))
if "error" in resp: if "error" in resp:
error = resp["error"] error = resp["error"]

View File

@ -16,13 +16,12 @@
import logging import logging
from typing import Any from typing import Any
from canonicaljson import json
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.util import json_decoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
except PartialDownloadError as pde: except PartialDownloadError as pde:
# Twisted is silly # Twisted is silly
data = pde.response data = pde.response
resp_body = json.loads(data.decode("utf-8")) resp_body = json_decoder.decode(data.decode("utf-8"))
if "success" in resp_body: if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly # Note that we do NOT check the hostname here: we explicitly

View File

@ -19,7 +19,7 @@ import urllib
from io import BytesIO from io import BytesIO
import treq import treq
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json
from netaddr import IPAddress from netaddr import IPAddress
from prometheus_client import Counter from prometheus_client import Counter
from zope.interface import implementer, provider from zope.interface import implementer, provider
@ -47,6 +47,7 @@ from synapse.http import (
from synapse.http.proxyagent import ProxyAgent from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred from synapse.util.async_helpers import timeout_deferred
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -391,7 +392,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response)) body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300: if 200 <= response.code < 300:
return json.loads(body.decode("utf-8")) return json_decoder.decode(body.decode("utf-8"))
else: else:
raise HttpResponseException( raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body response.code, response.phrase.decode("ascii", errors="replace"), body
@ -433,7 +434,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response)) body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300: if 200 <= response.code < 300:
return json.loads(body.decode("utf-8")) return json_decoder.decode(body.decode("utf-8"))
else: else:
raise HttpResponseException( raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body response.code, response.phrase.decode("ascii", errors="replace"), body
@ -463,7 +464,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers) actual_headers.update(headers)
body = await self.get_raw(uri, args, headers=headers) body = await self.get_raw(uri, args, headers=headers)
return json.loads(body.decode("utf-8")) return json_decoder.decode(body.decode("utf-8"))
async def put_json(self, uri, json_body, args={}, headers=None): async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI. """ Puts some json to the given URI.
@ -506,7 +507,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response)) body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300: if 200 <= response.code < 300:
return json.loads(body.decode("utf-8")) return json_decoder.decode(body.decode("utf-8"))
else: else:
raise HttpResponseException( raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body response.code, response.phrase.decode("ascii", errors="replace"), body

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
import random import random
import time import time
@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock from synapse.util import Clock, json_decoder
from synapse.util.caches.ttlcache import TTLCache from synapse.util.caches.ttlcache import TTLCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -181,7 +180,7 @@ class WellKnownResolver(object):
if response.code != 200: if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,)) raise Exception("Non-200 response %s" % (response.code,))
parsed_body = json.loads(body.decode("utf-8")) parsed_body = json_decoder.decode(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body) logger.info("Response from .well-known: %s", parsed_body)
result = parsed_body["m.server"].encode("ascii") result = parsed_body["m.server"].encode("ascii")

View File

@ -17,9 +17,8 @@
import logging import logging
from canonicaljson import json
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return None return None
try: try:
content = json.loads(content_bytes.decode("utf-8")) content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e: except Exception as e:
logger.warning("Unable to parse JSON: %s", e) logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View File

@ -177,6 +177,7 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.util import json_decoder
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -499,7 +500,9 @@ def start_active_span_from_edu(
if opentracing is None: if opentracing is None:
return _noop_context_manager() return _noop_context_manager()
carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {}) carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {}
)
context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
_references = [ _references = [
opentracing.child_of(span_context_from_string(x)) opentracing.child_of(span_context_from_string(x))
@ -699,7 +702,7 @@ def span_context_from_string(carrier):
Returns: Returns:
The active span context decoded from a string. The active span context decoded from a string.
""" """
carrier = json.loads(carrier) carrier = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)

View File

@ -21,9 +21,7 @@ import abc
import logging import logging
from typing import Tuple, Type from typing import Tuple, Type
from canonicaljson import json from synapse.util import json_decoder, json_encoder
from synapse.util import json_encoder as _json_encoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -125,7 +123,7 @@ class RdataCommand(Command):
stream_name, stream_name,
instance_name, instance_name,
None if token == "batch" else int(token), None if token == "batch" else int(token),
json.loads(row_json), json_decoder.decode(row_json),
) )
def to_line(self): def to_line(self):
@ -134,7 +132,7 @@ class RdataCommand(Command):
self.stream_name, self.stream_name,
self.instance_name, self.instance_name,
str(self.token) if self.token is not None else "batch", str(self.token) if self.token is not None else "batch",
_json_encoder.encode(self.row), json_encoder.encode(self.row),
) )
) )
@ -359,7 +357,7 @@ class UserIpCommand(Command):
def from_line(cls, line): def from_line(cls, line):
user_id, jsn = line.split(" ", 1) user_id, jsn = line.split(" ", 1)
access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen) return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
@ -367,7 +365,7 @@ class UserIpCommand(Command):
return ( return (
self.user_id self.user_id
+ " " + " "
+ _json_encoder.encode( + json_encoder.encode(
( (
self.access_token, self.access_token,
self.ip, self.ip,

View File

@ -21,8 +21,6 @@ import re
from typing import List, Optional from typing import List, Optional
from urllib import parse as urlparse from urllib import parse as urlparse
from canonicaljson import json
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -46,6 +44,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder
MYPY = False MYPY = False
if MYPY: if MYPY:
@ -519,7 +518,9 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8") filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str: if filter_str:
filter_json = urlparse.unquote(filter_str) filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
if ( if (
event_filter event_filter
and event_filter.filter_json.get("event_format", "client") and event_filter.filter_json.get("event_format", "client")
@ -631,7 +632,9 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8") filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str: if filter_str:
filter_json = urlparse.unquote(filter_str) filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
else: else:
event_filter = None event_filter = None

View File

@ -16,8 +16,6 @@
import itertools import itertools
import logging import logging
from canonicaljson import json
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.util import json_decoder
from ._base import client_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet):
filter_collection = DEFAULT_FILTER_COLLECTION filter_collection = DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"): elif filter_id.startswith("{"):
try: try:
filter_object = json.loads(filter_id) filter_object = json_decoder.decode(filter_id)
set_timeline_upper_limit( set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit filter_object, self.hs.config.filter_timeline_limit
) )

View File

@ -15,19 +15,19 @@
import logging import logging
from typing import Dict, Set from typing import Dict, Set
from canonicaljson import json
from signedjson.sign import sign_json from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.util import json_decoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RemoteKey(DirectServeJsonResource): class RemoteKey(DirectServeJsonResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature """HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of that the NACL signature for the remote server is valid. Returns a dict of
@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource):
# Cast to bytes since postgresql returns a memoryview. # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"])) json_results.add(bytes(result["key_json"]))
# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss: if cache_misses and query_remote_on_cache_miss:
await self.fetcher.get_keys(cache_misses) await self.fetcher.get_keys(cache_misses)
await self.query_keys(request, query, query_remote_on_cache_miss=False) await self.query_keys(request, query, query_remote_on_cache_miss=False)
else: else:
signed_keys = [] signed_keys = []
for key_json in json_results: for key_json in json_results:
key_json = json.loads(key_json.decode("utf-8")) key_json = json_decoder.decode(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys: for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key) key_json = sign_json(key_json, self.config.server_name, signing_key)

View File

@ -19,12 +19,11 @@ import random
from abc import ABCMeta from abc import ABCMeta
from typing import Any, Optional from typing import Any, Optional
from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id from synapse.types import Collection, get_domain_from_id
from synapse.util import json_decoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -99,13 +98,13 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview): if isinstance(db_content, memoryview):
db_content = db_content.tobytes() db_content = db_content.tobytes()
# Decode it to a Unicode string before feeding it to json.loads, since # Decode it to a Unicode string before feeding it to the JSON decoder, since
# Python 3.5 does not support deserializing bytes. # Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)): if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8") db_content = db_content.decode("utf8")
try: try:
return json.loads(db_content) return json_decoder.decode(db_content)
except Exception: except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content) logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise raise

View File

@ -596,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason: if not allow_rejected and rejected_reason:
continue continue
d = db_to_json(row["json"]) # If the event or metadata cannot be parsed, log the error and act
internal_metadata = db_to_json(row["internal_metadata"]) # as if the event is unknown.
try:
d = db_to_json(row["json"])
except ValueError:
logger.error("Unable to parse json from event: %s", event_id)
continue
try:
internal_metadata = db_to_json(row["internal_metadata"])
except ValueError:
logger.error(
"Unable to parse internal_metadata from event: %s", event_id
)
continue
format_version = row["format_version"] format_version = row["format_version"]
if format_version is None: if format_version is None:

View File

@ -25,8 +25,18 @@ from synapse.logging import context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Create a custom encoder to reduce the whitespace produced by JSON encoding.
json_encoder = json.JSONEncoder(separators=(",", ":")) def _reject_invalid_json(val):
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise json.JSONDecodeError("Invalid JSON value: '%s'" % val)
# Create a custom encoder to reduce the whitespace produced by JSON encoding and
# ensure that valid JSON is produced.
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
def unwrapFirstError(failure): def unwrapFirstError(failure):