Merge branch 'develop' into babolivier/port_db_account_validity

This commit is contained in:
Brendan Abolivier 2019-06-04 09:13:42 +01:00
commit aeb2263320
51 changed files with 529 additions and 417 deletions

1
changelog.d/5226.misc Normal file
View File

@ -0,0 +1 @@
The base classes for the v1 and v2_alpha REST APIs have been unified.

1
changelog.d/5307.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where a notary server would sometimes forget old keys.

1
changelog.d/5321.bugfix Normal file
View File

@ -0,0 +1 @@
Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.

1
changelog.d/5328.misc Normal file
View File

@ -0,0 +1 @@
The base classes for the v1 and v2_alpha REST APIs have been unified.

View File

@ -20,9 +20,7 @@ class CallVisitor(ast.NodeVisitor):
else: else:
return return
if name == "client_path_patterns": if name == "client_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
PATTERNS_V2.append(node.args[0].s) PATTERNS_V2.append(node.args[0].s)

View File

@ -37,8 +37,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha._base import client_v2_patterns
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
@ -49,11 +48,11 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy") logger = logging.getLogger("synapse.app.frontend_proxy")
class PresenceStatusStubServlet(ClientV1RestServlet): class PresenceStatusStubServlet(RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs): def __init__(self, hs):
super(PresenceStatusStubServlet, self).__init__(hs) super(PresenceStatusStubServlet, self).__init__()
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.main_uri = hs.config.worker_main_http_uri self.main_uri = hs.config.worker_main_http_uri
@ -84,7 +83,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet):
class KeyUploadServlet(RestServlet): class KeyUploadServlet(RestServlet):
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import defaultdict
import six import six
from six import raise_from from six import raise_from
@ -70,6 +71,9 @@ class VerifyKeyRequest(object):
json_object(dict): The JSON object to verify. json_object(dict): The JSON object to verify.
minimum_valid_until_ts (int): time at which we require the signing key to
be valid. (0 implies we don't care)
deferred(Deferred[str, str, nacl.signing.VerifyKey]): deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no a verify key has been fetched. The deferreds' callbacks are run with no
@ -82,7 +86,8 @@ class VerifyKeyRequest(object):
server_name = attr.ib() server_name = attr.ib()
key_ids = attr.ib() key_ids = attr.ib()
json_object = attr.ib() json_object = attr.ib()
deferred = attr.ib() minimum_valid_until_ts = attr.ib()
deferred = attr.ib(default=attr.Factory(defer.Deferred))
class KeyLookupError(ValueError): class KeyLookupError(ValueError):
@ -90,14 +95,16 @@ class KeyLookupError(ValueError):
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._key_fetchers = ( if key_fetchers is None:
StoreKeyFetcher(hs), key_fetchers = (
PerspectivesKeyFetcher(hs), StoreKeyFetcher(hs),
ServerKeyFetcher(hs), PerspectivesKeyFetcher(hs),
) ServerKeyFetcher(hs),
)
self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with # map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download # an ongoing key download; the Deferred completes once the download
@ -106,9 +113,25 @@ class Keyring(object):
# These are regular, logcontext-agnostic Deferreds. # These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {} self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object, validity_time):
"""Verify that a JSON object has been signed by a given server
Args:
server_name (str): name of the server which must have signed this object
json_object (dict): object to be checked
validity_time (int): timestamp at which we require the signing key to
be valid. (0 implies we don't care)
Returns:
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
req = server_name, json_object, validity_time
return logcontext.make_deferred_yieldable( return logcontext.make_deferred_yieldable(
self.verify_json_objects_for_server([(server_name, json_object)])[0] self.verify_json_objects_for_server((req,))[0]
) )
def verify_json_objects_for_server(self, server_and_json): def verify_json_objects_for_server(self, server_and_json):
@ -116,10 +139,12 @@ class Keyring(object):
necessary. necessary.
Args: Args:
server_and_json (list): List of pairs of (server_name, json_object) server_and_json (iterable[Tuple[str, dict, int]):
Iterable of triplets of (server_name, json_object, validity_time)
validity_time is a timestamp at which the signing key must be valid.
Returns: Returns:
List<Deferred>: for each input pair, a deferred indicating success List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel server_name. The deferreds run their callbacks in the sentinel
logcontext. logcontext.
@ -128,12 +153,12 @@ class Keyring(object):
verify_requests = [] verify_requests = []
handle = preserve_fn(_handle_key_deferred) handle = preserve_fn(_handle_key_deferred)
def process(server_name, json_object): def process(server_name, json_object, validity_time):
"""Process an entry in the request list """Process an entry in the request list
Given a (server_name, json_object) pair from the request list, Given a (server_name, json_object, validity_time) triplet from the request
adds a key request to verify_requests, and returns a deferred which will list, adds a key request to verify_requests, and returns a deferred which
complete or fail (in the sentinel context) when verification completes. will complete or fail (in the sentinel context) when verification completes.
""" """
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
@ -148,7 +173,7 @@ class Keyring(object):
# add the key request to the queue, but don't start it off yet. # add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest( verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, defer.Deferred() server_name, key_ids, json_object, validity_time
) )
verify_requests.append(verify_request) verify_requests.append(verify_request)
@ -160,8 +185,8 @@ class Keyring(object):
return handle(verify_request) return handle(verify_request)
results = [ results = [
process(server_name, json_object) process(server_name, json_object, validity_time)
for server_name, json_object in server_and_json for server_name, json_object, validity_time in server_and_json
] ]
if verify_requests: if verify_requests:
@ -298,8 +323,12 @@ class Keyring(object):
verify_request.deferred.errback( verify_request.deferred.errback(
SynapseError( SynapseError(
401, 401,
"No key for %s with id %s" "No key for %s with ids in %s (min_validity %i)"
% (verify_request.server_name, verify_request.key_ids), % (
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
) )
@ -323,18 +352,28 @@ class Keyring(object):
Args: Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys fetcher (KeyFetcher): fetcher to use to fetch the keys
remaining_requests (set[VerifyKeyRequest]): outstanding key requests. remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
Any successfully-completed requests will be reomved from the list. Any successfully-completed requests will be removed from the list.
""" """
# dict[str, set(str)]: keys to fetch for each server # dict[str, dict[str, int]]: keys to fetch.
missing_keys = {} # server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict)
for verify_request in remaining_requests: for verify_request in remaining_requests:
# any completed requests should already have been removed # any completed requests should already have been removed
assert not verify_request.deferred.called assert not verify_request.deferred.called
missing_keys.setdefault(verify_request.server_name, set()).update( keys_for_server = missing_keys[verify_request.server_name]
verify_request.key_ids
)
results = yield fetcher.get_keys(missing_keys.items()) for key_id in verify_request.key_ids:
# If we have several requests for the same key, then we only need to
# request that key once, but we should do so with the greatest
# min_valid_until_ts of the requests, so that we can satisfy all of
# the requests.
keys_for_server[key_id] = max(
keys_for_server.get(key_id, -1),
verify_request.minimum_valid_until_ts
)
results = yield fetcher.get_keys(missing_keys)
completed = list() completed = list()
for verify_request in remaining_requests: for verify_request in remaining_requests:
@ -344,25 +383,34 @@ class Keyring(object):
# complete this VerifyKeyRequest. # complete this VerifyKeyRequest.
result_keys = results.get(server_name, {}) result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids: for key_id in verify_request.key_ids:
key = result_keys.get(key_id) fetch_key_result = result_keys.get(key_id)
if key: if not fetch_key_result:
with PreserveLoggingContext(): # we didn't get a result for this key
verify_request.deferred.callback( continue
(server_name, key_id, key.verify_key)
) if (
completed.append(verify_request) fetch_key_result.valid_until_ts
break < verify_request.minimum_valid_until_ts
):
# key was not valid at this point
continue
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, fetch_key_result.verify_key)
)
completed.append(verify_request)
break
remaining_requests.difference_update(completed) remaining_requests.difference_update(completed)
class KeyFetcher(object): class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
""" """
Args: Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): keys_to_fetch (dict[str, dict[str, int]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for the keys to be fetched. server_name -> key_id -> min_valid_ts
Note that the iterables may be iterated more than once.
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@ -378,13 +426,15 @@ class StoreKeyFetcher(KeyFetcher):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
keys_to_fetch = ( keys_to_fetch = (
(server_name, key_id) (server_name, key_id)
for server_name, key_ids in server_name_and_key_ids for server_name, keys_for_server in keys_to_fetch.items()
for key_id in key_ids for key_id in keys_for_server.keys()
) )
res = yield self.store.get_server_verify_keys(keys_to_fetch) res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {} keys = {}
for (server_name, key_id), key in res.items(): for (server_name, key_id), key in res.items():
@ -399,7 +449,7 @@ class BaseV2KeyFetcher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response( def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[] self, from_server, response_json, time_added_ms
): ):
"""Parse a 'Server Keys' structure from the result of a /key request """Parse a 'Server Keys' structure from the result of a /key request
@ -422,10 +472,6 @@ class BaseV2KeyFetcher(object):
time_added_ms (int): the timestamp to record in server_keys_json time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns: Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
""" """
@ -481,11 +527,6 @@ class BaseV2KeyFetcher(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable( yield logcontext.make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
@ -498,7 +539,7 @@ class BaseV2KeyFetcher(object):
ts_expires_ms=ts_valid_until_ms, ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes, key_json_bytes=signed_key_json_bytes,
) )
for key_id in updated_key_ids for key_id in verify_keys
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -517,14 +558,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.perspective_servers = self.config.perspectives self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = yield self.get_server_verify_key_v2_indirect(
server_name_and_key_ids, perspective_name, perspective_keys keys_to_fetch, perspective_name, perspective_keys
) )
defer.returnValue(result) defer.returnValue(result)
except KeyLookupError as e: except KeyLookupError as e:
@ -558,13 +599,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect( def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys self, keys_to_fetch, perspective_name, perspective_keys
): ):
""" """
Args: Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): keys_to_fetch (dict[str, dict[str, int]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for the keys to be fetched. server_name -> key_id -> min_valid_ts
perspective_name (str): name of the notary server to query for the keys perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server notary server
@ -578,12 +621,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
""" """
logger.info( logger.info(
"Requesting keys %s from notary server %s", "Requesting keys %s from notary server %s",
server_names_and_key_ids, keys_to_fetch.items(),
perspective_name, perspective_name,
) )
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
try: try:
query_response = yield self.client.post_json( query_response = yield self.client.post_json(
destination=perspective_name, destination=perspective_name,
@ -591,9 +632,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
data={ data={
u"server_keys": { u"server_keys": {
server_name: { server_name: {
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids key_id: {u"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
} }
for server_name, key_ids in server_names_and_key_ids for server_name, server_keys in keys_to_fetch.items()
} }
}, },
long_retries=True, long_retries=True,
@ -703,15 +745,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_http_client() self.client = hs.get_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
# TODO make this more resilient
results = yield logcontext.make_deferred_yieldable( results = yield logcontext.make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background( run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids self.get_server_verify_key_v2_direct,
server_name,
server_keys.keys(),
) )
for server_name, key_ids in server_name_and_key_ids for server_name, server_keys in keys_to_fetch.items()
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -730,6 +775,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
keys = {} # type: dict[str, FetchKeyResult] keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids: for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
if requested_key_id in keys: if requested_key_id in keys:
continue continue
@ -754,7 +800,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
requested_ids=[requested_key_id],
response_json=response, response_json=response,
time_added_ms=time_now_ms, time_added_ms=time_now_ms,
) )

View File

@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
] ]
more_deferreds = keyring.verify_json_objects_for_server([ more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json) (p.sender_domain, p.redacted_pdu_json, 0)
for p in pdus_to_check_sender for p in pdus_to_check_sender
]) ])
@ -298,7 +298,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
] ]
more_deferreds = keyring.verify_json_objects_for_server([ more_deferreds = keyring.verify_json_objects_for_server([
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json) (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
for p in pdus_to_check_event_id for p in pdus_to_check_event_id
]) ])

View File

@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object): class Authenticator(object):
def __init__(self, hs): def __init__(self, hs):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -102,6 +103,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request, content): def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = { json_request = {
"method": request.method.decode('ascii'), "method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'), "uri": request.uri.decode('ascii'),
@ -138,7 +140,7 @@ class Authenticator(object):
401, "Missing Authorization headers", Codes.UNAUTHORIZED, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
yield self.keyring.verify_json_for_server(origin, json_request) yield self.keyring.verify_json_for_server(origin, json_request, now)
logger.info("Request from %s", origin) logger.info("Request from %s", origin)
request.authenticated_entity = origin request.authenticated_entity = origin

View File

@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
# TODO: We also want to check that *new* attestations that people give # TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while. # us to store are valid for at least a little while.
if valid_until_ms < self.clock.time_msec(): now = self.clock.time_msec()
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired") raise SynapseError(400, "Attestation expired")
yield self.keyring.verify_json_for_server(server_name, attestation) yield self.keyring.verify_json_for_server(server_name, attestation, now)
def create_attestation(self, group_id, user_id): def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default """Create an attestation for the group_id and user_id with default

View File

@ -1,65 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
import logging
import re
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.servlet import RestServlet
from synapse.rest.client.transactions import HttpTransactionCache
logger = logging.getLogger(__name__)
def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_API_PREFIX + "/api/v1" + path_regex)]
if include_in_unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
# This subclass was presumably created to allow the auth for the v1
# protocol version to be different, however this behaviour was removed.
# it may no longer be necessary
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)

View File

@ -19,11 +19,10 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,13 +32,14 @@ def register_servlets(hs, http_server):
ClientAppserviceDirectoryListServer(hs).register(http_server) ClientAppserviceDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs) super(ClientDirectoryServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):
@ -120,13 +120,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ClientDirectoryListServer(ClientV1RestServlet): class ClientDirectoryListServer(RestServlet):
PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$") PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs) super(ClientDirectoryListServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -162,15 +163,16 @@ class ClientDirectoryListServer(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ClientAppserviceDirectoryListServer(ClientV1RestServlet): class ClientAppserviceDirectoryListServer(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$" "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(ClientAppserviceDirectoryListServer, self).__init__(hs) super(ClientAppserviceDirectoryListServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id): def on_PUT(self, request, network_id, room_id):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)

View File

@ -19,21 +19,22 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventStreamRestServlet(ClientV1RestServlet): class EventStreamRestServlet(RestServlet):
PATTERNS = client_path_patterns("/events$") PATTERNS = client_patterns("/events$", v1=True)
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs): def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs) super(EventStreamRestServlet, self).__init__()
self.event_stream_handler = hs.get_event_stream_handler() self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -76,11 +77,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events. # TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(ClientV1RestServlet): class EventRestServlet(RestServlet):
PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$") PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(EventRestServlet, self).__init__(hs) super(EventRestServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()

View File

@ -15,19 +15,19 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_boolean from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
# TODO: Needs unit testing # TODO: Needs unit testing
class InitialSyncRestServlet(ClientV1RestServlet): class InitialSyncRestServlet(RestServlet):
PATTERNS = client_path_patterns("/initialSync$") PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs) super(InitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -29,12 +29,11 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -81,15 +80,16 @@ def login_id_thirdparty_from_phone(identifier):
} }
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(RestServlet):
PATTERNS = client_path_patterns("/login$") PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso" SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token" TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt" JWT_TYPE = "m.login.jwt"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__()
self.hs = hs
self.jwt_enabled = hs.config.jwt_enabled self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm self.jwt_algorithm = hs.config.jwt_algorithm
@ -371,7 +371,7 @@ class LoginRestServlet(ClientV1RestServlet):
class CasRedirectServlet(RestServlet): class CasRedirectServlet(RestServlet):
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(CasRedirectServlet, self).__init__() super(CasRedirectServlet, self).__init__()
@ -394,27 +394,27 @@ class CasRedirectServlet(RestServlet):
finish_request(request) finish_request(request)
class CasTicketServlet(ClientV1RestServlet): class CasTicketServlet(RestServlet):
PATTERNS = client_path_patterns("/login/cas/ticket") PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs) super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_simple_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True) client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate" uri = self.cas_server_url + "/proxyValidate"
args = { args = {
"ticket": parse_string(request, "ticket", required=True), "ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url "service": self.cas_service_url
} }
try: try:
body = yield http_client.get_raw(uri, args) body = yield self._http_client.get_raw(uri, args)
except PartialDownloadError as pde: except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed, # Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data # even if that's being used old-http style to signal end-of-data

View File

@ -17,17 +17,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LogoutRestServlet(ClientV1RestServlet): class LogoutRestServlet(RestServlet):
PATTERNS = client_path_patterns("/logout$") PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs) super(LogoutRestServlet, self).__init__()
self._auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()
@ -41,7 +42,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None: if requester.device_id is None:
# the acccess token wasn't associated with a device. # the acccess token wasn't associated with a device.
# Just delete the access token # Just delete the access token
access_token = self._auth.get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token) yield self._auth_handler.delete_access_token(access_token)
else: else:
yield self._device_handler.delete_device( yield self._device_handler.delete_device(
@ -50,11 +51,11 @@ class LogoutRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class LogoutAllRestServlet(ClientV1RestServlet): class LogoutAllRestServlet(RestServlet):
PATTERNS = client_path_patterns("/logout/all$") PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs) super(LogoutAllRestServlet, self).__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()

View File

@ -23,21 +23,22 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs) super(PresenceStatusRestServlet, self).__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View File

@ -16,18 +16,19 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
class ProfileDisplaynameRestServlet(RestServlet):
class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs): def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs) super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -71,12 +72,14 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
return (200, {}) return (200, {})
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs) super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -119,12 +122,14 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
return (200, {}) return (200, {})
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs) super(ProfileRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View File

@ -21,22 +21,22 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.servlet import parse_json_value_from_request, parse_string from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from .base import ClientV1RestServlet, client_path_patterns
class PushRuleRestServlet(RestServlet):
class PushRuleRestServlet(ClientV1RestServlet): PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
PATTERNS = client_path_patterns("/(?P<path>pushrules/.*)$")
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs): def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs) super(PushRuleRestServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None self._is_worker = hs.config.worker_app is not None

View File

@ -26,17 +26,18 @@ from synapse.http.servlet import (
parse_string, parse_string,
) )
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.rest.client.v2_alpha._base import client_patterns
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PushersRestServlet(ClientV1RestServlet): class PushersRestServlet(RestServlet):
PATTERNS = client_path_patterns("/pushers$") PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PushersRestServlet, self).__init__(hs) super(PushersRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -69,11 +70,13 @@ class PushersRestServlet(ClientV1RestServlet):
return 200, {} return 200, {}
class PushersSetRestServlet(ClientV1RestServlet): class PushersSetRestServlet(RestServlet):
PATTERNS = client_path_patterns("/pushers/set$") PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PushersSetRestServlet, self).__init__(hs) super(PushersSetRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
@ -141,7 +144,7 @@ class PushersRemoveRestServlet(RestServlet):
""" """
To allow pusher to be delete by clicking a link (ie. GET request) To allow pusher to be delete by clicking a link (ie. GET request)
""" """
PATTERNS = client_path_patterns("/pushers/remove$") PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs): def __init__(self, hs):

View File

@ -28,37 +28,45 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2 from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet): class TransactionRestServlet(RestServlet):
def __init__(self, hs):
super(TransactionRestServlet, self).__init__()
self.txns = HttpTransactionCache(hs)
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def __init__(self, hs): def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs) super(RoomCreateRestServlet, self).__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler() self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_paths("OPTIONS", http_server.register_paths("OPTIONS",
client_path_patterns("/rooms(?:/.*)?$"), client_patterns("/rooms(?:/.*)?$", v1=True),
self.on_OPTIONS) self.on_OPTIONS)
# define CORS for /createRoom[/txnid] # define CORS for /createRoom[/txnid]
http_server.register_paths("OPTIONS", http_server.register_paths("OPTIONS",
client_path_patterns("/createRoom(?:/.*)?$"), client_patterns("/createRoom(?:/.*)?$", v1=True),
self.on_OPTIONS) self.on_OPTIONS)
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
@ -85,13 +93,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet): class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs) super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
@ -102,16 +111,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_paths("GET", http_server.register_paths("GET",
client_path_patterns(state_key), client_patterns(state_key, v1=True),
self.on_GET) self.on_GET)
http_server.register_paths("PUT", http_server.register_paths("PUT",
client_path_patterns(state_key), client_patterns(state_key, v1=True),
self.on_PUT) self.on_PUT)
http_server.register_paths("GET", http_server.register_paths("GET",
client_path_patterns(no_state_key), client_patterns(no_state_key, v1=True),
self.on_GET_no_state_key) self.on_GET_no_state_key)
http_server.register_paths("PUT", http_server.register_paths("PUT",
client_path_patterns(no_state_key), client_patterns(no_state_key, v1=True),
self.on_PUT_no_state_key) self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type): def on_GET_no_state_key(self, request, room_id, event_type):
@ -185,11 +194,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet): class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs) super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
@ -229,10 +239,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs) super(JoinRoomAliasServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@ -291,8 +302,13 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class PublicRoomListRestServlet(ClientV1RestServlet): class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_path_patterns("/publicRooms$") PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs):
super(PublicRoomListRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -382,12 +398,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs) super(RoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -436,12 +453,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# deprecated in favour of /members?membership=join? # deprecated in favour of /members?membership=join?
# except it does custom AS logic and has a simpler return format # except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(ClientV1RestServlet): class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs) super(JoinedRoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -457,12 +475,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs better unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs) super(RoomMessageListRestServlet, self).__init__()
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -491,12 +510,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomStateRestServlet(ClientV1RestServlet): class RoomStateRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs) super(RoomStateRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -511,12 +531,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs) super(RoomInitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -530,16 +551,17 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class RoomEventServlet(ClientV1RestServlet): class RoomEventServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomEventServlet, self).__init__(hs) super(RoomEventServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -554,16 +576,17 @@ class RoomEventServlet(ClientV1RestServlet):
defer.returnValue((404, "Event not found.")) defer.returnValue((404, "Event not found."))
class RoomEventContextServlet(ClientV1RestServlet): class RoomEventContextServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomEventContextServlet, self).__init__(hs) super(RoomEventContextServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler() self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -609,10 +632,11 @@ class RoomEventContextServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
class RoomForgetRestServlet(ClientV1RestServlet): class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs) super(RoomForgetRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@ -639,11 +663,12 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet): class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs) super(RoomMembershipRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
@ -722,11 +747,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
) )
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs) super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@ -757,15 +783,16 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
) )
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs) super(RoomTypingRestServlet, self).__init__()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler() self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
@ -798,14 +825,13 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet): class SearchRestServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns("/search$", v1=True)
"/search$"
)
def __init__(self, hs): def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs) super(SearchRestServlet, self).__init__()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -823,12 +849,13 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
class JoinedRoomsRestServlet(ClientV1RestServlet): class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_path_patterns("/joined_rooms$") PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(JoinedRoomsRestServlet, self).__init__(hs) super(JoinedRoomsRestServlet, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -853,18 +880,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
""" """
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_path_patterns(regex_string + "$"), client_patterns(regex_string + "$", v1=True),
servlet.on_POST servlet.on_POST
) )
http_server.register_paths( http_server.register_paths(
"PUT", "PUT",
client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT servlet.on_PUT
) )
if with_get: if with_get:
http_server.register_paths( http_server.register_paths(
"GET", "GET",
client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET servlet.on_GET
) )

View File

@ -19,11 +19,17 @@ import hmac
from twisted.internet import defer from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
class VoipRestServlet(ClientV1RestServlet): class VoipRestServlet(RestServlet):
PATTERNS = client_path_patterns("/voip/turnServer$") PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs):
super(VoipRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -26,8 +26,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,), def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
unstable=True):
"""Creates a regex compiled client path with the correct client path """Creates a regex compiled client path with the correct client path
prefix. prefix.
@ -41,6 +40,9 @@ def client_v2_patterns(path_regex, releases=(0,),
if unstable: if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable" unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex)) patterns.append(re.compile("^" + unstable_prefix + path_regex))
if v1:
v1_prefix = CLIENT_API_PREFIX + "/api/v1"
patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases: for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex)) patterns.append(re.compile("^" + new_prefix + path_regex))

View File

@ -30,13 +30,13 @@ from synapse.http.servlet import (
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet): class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$") PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
super(EmailPasswordRequestTokenRestServlet, self).__init__() super(EmailPasswordRequestTokenRestServlet, self).__init__()
@ -70,7 +70,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class MsisdnPasswordRequestTokenRestServlet(RestServlet): class MsisdnPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$") PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__() super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
@ -108,7 +108,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet): class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password$") PATTERNS = client_patterns("/account/password$")
def __init__(self, hs): def __init__(self, hs):
super(PasswordRestServlet, self).__init__() super(PasswordRestServlet, self).__init__()
@ -180,7 +180,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$") PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs): def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__() super(DeactivateAccountRestServlet, self).__init__()
@ -228,7 +228,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
@ -263,7 +263,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$") PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
@ -300,7 +300,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet): class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid$") PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs): def __init__(self, hs):
super(ThreepidRestServlet, self).__init__() super(ThreepidRestServlet, self).__init__()
@ -364,7 +364,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$") PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs): def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__() super(ThreepidDeleteRestServlet, self).__init__()
@ -401,7 +401,7 @@ class ThreepidDeleteRestServlet(RestServlet):
class WhoamiRestServlet(RestServlet): class WhoamiRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/whoami$") PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs): def __init__(self, hs):
super(WhoamiRestServlet, self).__init__() super(WhoamiRestServlet, self).__init__()

View File

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,7 +30,7 @@ class AccountDataServlet(RestServlet):
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
) )
@ -79,7 +79,7 @@ class RoomAccountDataServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)" "/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)" "/account_data/(?P<account_data_type>[^/]*)"

View File

@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet): class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/renew$") PATTERNS = client_patterns("/account_validity/renew$")
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>" SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
def __init__(self, hs): def __init__(self, hs):
@ -60,7 +60,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/send_mail$") PATTERNS = client_patterns("/account_validity/send_mail$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -23,7 +23,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string from synapse.http.servlet import RestServlet, parse_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -122,7 +122,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint). cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth. Current use is for web fallback auth.
""" """
PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs): def __init__(self, hs):
super(AuthRestServlet, self).__init__() super(AuthRestServlet, self).__init__()

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class CapabilitiesRestServlet(RestServlet): class CapabilitiesRestServlet(RestServlet):
"""End point to expose the capabilities of the server.""" """End point to expose the capabilities of the server."""
PATTERNS = client_v2_patterns("/capabilities$") PATTERNS = client_patterns("/capabilities$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet): class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$") PATTERNS = client_patterns("/devices$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth. key which lists the device_ids to delete. Requires user interactive auth.
""" """
PATTERNS = client_v2_patterns("/delete_devices") PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs): def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__() super(DeleteDevicesRestServlet, self).__init__()
@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet): class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$") PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
from ._base import client_v2_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
super(GetFilterRestServlet, self).__init__() super(GetFilterRestServlet, self).__init__()
@ -63,7 +63,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs): def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__() super(CreateFilterRestServlet, self).__init__()

View File

@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID from synapse.types import GroupID
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class GroupServlet(RestServlet): class GroupServlet(RestServlet):
"""Get the group profile """Get the group profile
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs): def __init__(self, hs):
super(GroupServlet, self).__init__() super(GroupServlet, self).__init__()
@ -65,7 +65,7 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet): class GroupSummaryServlet(RestServlet):
"""Get the full group summary """Get the full group summary
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs): def __init__(self, hs):
super(GroupSummaryServlet, self).__init__() super(GroupSummaryServlet, self).__init__()
@ -93,7 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
- /groups/:group/summary/rooms/:room_id - /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary" "/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?" "(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)$"
@ -137,7 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet): class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category """Get/add/update/delete a group category
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
) )
@ -189,7 +189,7 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet): class GroupCategoriesServlet(RestServlet):
"""Get all group categories """Get all group categories
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/$" "/groups/(?P<group_id>[^/]*)/categories/$"
) )
@ -214,7 +214,7 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet): class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role """Get/add/update/delete a group role
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
) )
@ -266,7 +266,7 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet): class GroupRolesServlet(RestServlet):
"""Get all group roles """Get all group roles
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/$" "/groups/(?P<group_id>[^/]*)/roles/$"
) )
@ -295,7 +295,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
- /groups/:group/summary/users/:room_id - /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id - /groups/:group/summary/roles/:role/users/:user_id
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary" "/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?" "(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)$" "/users/(?P<user_id>[^/]*)$"
@ -339,7 +339,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet): class GroupRoomServlet(RestServlet):
"""Get all rooms in a group """Get all rooms in a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs): def __init__(self, hs):
super(GroupRoomServlet, self).__init__() super(GroupRoomServlet, self).__init__()
@ -360,7 +360,7 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet): class GroupUsersServlet(RestServlet):
"""Get all users in a group """Get all users in a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs): def __init__(self, hs):
super(GroupUsersServlet, self).__init__() super(GroupUsersServlet, self).__init__()
@ -381,7 +381,7 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group """Get users invited to a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs): def __init__(self, hs):
super(GroupInvitedUsersServlet, self).__init__() super(GroupInvitedUsersServlet, self).__init__()
@ -405,7 +405,7 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy """Set group join policy
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs): def __init__(self, hs):
super(GroupSettingJoinPolicyServlet, self).__init__() super(GroupSettingJoinPolicyServlet, self).__init__()
@ -431,7 +431,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet): class GroupCreateServlet(RestServlet):
"""Create a group """Create a group
""" """
PATTERNS = client_v2_patterns("/create_group$") PATTERNS = client_patterns("/create_group$")
def __init__(self, hs): def __init__(self, hs):
super(GroupCreateServlet, self).__init__() super(GroupCreateServlet, self).__init__()
@ -462,7 +462,7 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group """Add a room to the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
) )
@ -499,7 +499,7 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group """Update the config of a room in a group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$" "/config/(?P<config_key>[^/]*)$"
) )
@ -526,7 +526,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group """Invite a user to the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
) )
@ -555,7 +555,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group """Kick a user from the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
) )
@ -581,7 +581,7 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group """Leave a joined group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/leave$" "/groups/(?P<group_id>[^/]*)/self/leave$"
) )
@ -607,7 +607,7 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet): class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock """Attempt to join a group, or knock
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/join$" "/groups/(?P<group_id>[^/]*)/self/join$"
) )
@ -633,7 +633,7 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite """Accept a group invite
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/accept_invite$" "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
) )
@ -659,7 +659,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group """Update whether we publicise a users membership of a group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/update_publicity$" "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
) )
@ -686,7 +686,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising """Get the list of groups a user is advertising
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/publicised_groups/(?P<user_id>[^/]*)$" "/publicised_groups/(?P<user_id>[^/]*)$"
) )
@ -711,7 +711,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising """Get the list of groups a user is advertising
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/publicised_groups$" "/publicised_groups$"
) )
@ -739,7 +739,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet): class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to """Get all groups the logged in user is joined to
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/joined_groups$" "/joined_groups$"
) )

View File

@ -26,7 +26,7 @@ from synapse.http.servlet import (
) )
from synapse.types import StreamToken from synapse.types import StreamToken
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class KeyUploadServlet(RestServlet):
}, },
} }
""" """
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -130,7 +130,7 @@ class KeyQueryServlet(RestServlet):
} } } } } } } } } } } }
""" """
PATTERNS = client_v2_patterns("/keys/query$") PATTERNS = client_patterns("/keys/query$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -159,7 +159,7 @@ class KeyChangesServlet(RestServlet):
200 OK 200 OK
{ "changed": ["@foo:example.com"] } { "changed": ["@foo:example.com"] }
""" """
PATTERNS = client_v2_patterns("/keys/changes$") PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -209,7 +209,7 @@ class OneTimeKeyServlet(RestServlet):
} } } } } } } }
""" """
PATTERNS = client_v2_patterns("/keys/claim$") PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs): def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__() super(OneTimeKeyServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet): class NotificationsServlet(RestServlet):
PATTERNS = client_v2_patterns("/notifications$") PATTERNS = client_patterns("/notifications$")
def __init__(self, hs): def __init__(self, hs):
super(NotificationsServlet, self).__init__() super(NotificationsServlet, self).__init__()

View File

@ -22,7 +22,7 @@ from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class IdTokenServlet(RestServlet):
"expires_in": 3600, "expires_in": 3600,
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/openid/request_token" "/user/(?P<user_id>[^/]*)/openid/request_token"
) )

View File

@ -19,13 +19,13 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet): class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs): def __init__(self, hs):
super(ReadMarkerRestServlet, self).__init__() super(ReadMarkerRestServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet): class ReceiptRestServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)$"

View File

@ -43,7 +43,7 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't # We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison # exist. It's a _really minor_ security flaw to use plain string comparison
@ -60,7 +60,7 @@ logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet): class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$") PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -98,7 +98,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/msisdn/requestToken$") PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -142,7 +142,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/available") PATTERNS = client_patterns("/register/available")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -182,7 +182,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$") PATTERNS = client_patterns("/register$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -34,7 +34,7 @@ from synapse.http.servlet import (
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,12 +66,12 @@ class RelationSendServlet(RestServlet):
def register(self, http_server): def register(self, http_server):
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_v2_patterns(self.PATTERN + "$", releases=()), client_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST, self.on_PUT_or_POST,
) )
http_server.register_paths( http_server.register_paths(
"PUT", "PUT",
client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT, self.on_PUT,
) )
@ -120,7 +120,7 @@ class RelationPaginationServlet(RestServlet):
filtered by relation type and event type. filtered by relation type and event type.
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(), releases=(),
@ -197,7 +197,7 @@ class RelationAggregationPaginationServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(), releases=(),
@ -269,7 +269,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$", "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
releases=(), releases=(),

View File

@ -27,13 +27,13 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet): class ReportEventRestServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
) )

View File

@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_string, parse_string,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RoomKeysServlet(RestServlet): class RoomKeysServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
) )
@ -256,7 +256,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/version$" "/room_keys/version$"
) )
@ -314,7 +314,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/version(/(?P<version>[^/]+))?$" "/room_keys/version(/(?P<version>[^/]+))?$"
) )

View File

@ -25,7 +25,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,7 +47,7 @@ class RoomUpgradeRestServlet(RestServlet):
Args: Args:
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
# /rooms/$roomid/upgrade # /rooms/$roomid/upgrade
"/rooms/(?P<room_id>[^/]*)/upgrade$", "/rooms/(?P<room_id>[^/]*)/upgrade$",
) )

View File

@ -21,13 +21,13 @@ from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet): class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
) )

View File

@ -32,7 +32,7 @@ from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken from synapse.types import StreamToken
from ._base import client_v2_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,7 +73,7 @@ class SyncRestServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns("/sync$") PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
def __init__(self, hs): def __init__(self, hs):

View File

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ class TagListServlet(RestServlet):
""" """
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
) )
@ -54,7 +54,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
) )

View File

@ -21,13 +21,13 @@ from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols") PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__() super(ThirdPartyProtocolsServlet, self).__init__()
@ -44,7 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__() super(ThirdPartyProtocolServlet, self).__init__()
@ -66,7 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__() super(ThirdPartyUserServlet, self).__init__()
@ -89,7 +89,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__() super(ThirdPartyLocationServlet, self).__init__()

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh Exchanges refresh tokens for a pair of an access token and a new refresh
token. token.
""" """
PATTERNS = client_v2_patterns("/tokenrefresh") PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet): class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user_directory/search$") PATTERNS = client_patterns("/user_directory/search$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -21,4 +21,4 @@ import tests.patch_inline_callbacks
# attempt to do the patch before we load any synapse code # attempt to do the patch before we load any synapse code
tests.patch_inline_callbacks.do_patch() tests.patch_inline_callbacks.do_patch()
util.DEFAULT_TIMEOUT_DURATION = 10 util.DEFAULT_TIMEOUT_DURATION = 20

View File

@ -19,6 +19,7 @@ from mock import Mock
import canonicaljson import canonicaljson
import signedjson.key import signedjson.key
import signedjson.sign import signedjson.sign
from signedjson.key import get_verify_key
from twisted.internet import defer from twisted.internet import defer
@ -137,7 +138,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
context_11.request = "11" context_11.request = "11"
res_deferreds = kr.verify_json_objects_for_server( res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1), ("server11", {})] [("server10", json1, 0), ("server11", {}, 0)]
) )
# the unsigned json should be rejected pretty quickly # the unsigned json should be rejected pretty quickly
@ -174,7 +175,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.return_value = defer.Deferred() self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server( res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)] [("server10", json1, 0)]
) )
res_deferreds_2[0].addBoth(self.check_context, None) res_deferreds_2[0].addBoth(self.check_context, None)
yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@ -197,31 +198,108 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs) kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)
key1_id = "%s:%s" % (key1.alg, key1.version)
r = self.hs.datastore.store_server_verify_keys( r = self.hs.datastore.store_server_verify_keys(
"server9", "server9",
time.time() * 1000, time.time() * 1000,
[ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
(
"server9",
key1_id,
FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
),
],
) )
self.get_success(r) self.get_success(r)
json1 = {} json1 = {}
signedjson.sign.sign_json(json1, "server9", key1) signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object # should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}) d = _verify_json_for_server(kr, "server9", {}, 0)
self.failureResultOf(d, SynapseError) self.failureResultOf(d, SynapseError)
d = _verify_json_for_server(kr, "server9", json1) # should suceed on a signed object
self.assertFalse(d.called) d = _verify_json_for_server(kr, "server9", json1, 500)
# self.assertFalse(d.called)
self.get_success(d) self.get_success(d)
def test_verify_json_dedupes_key_requests(self):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
)
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
json1 = {}
signedjson.sign.sign_json(json1, "server1", key1)
# the first request should succeed; the second should fail because the key
# has expired
results = kr.verify_json_objects_for_server(
[("server1", json1, 500), ("server1", json1, 1500)]
)
self.assertEqual(len(results), 2)
self.get_success(results[0])
e = self.get_failure(results[1], SynapseError).value
self.assertEqual(e.errcode, "M_UNAUTHORIZED")
self.assertEqual(e.code, 401)
# there should have been a single call to the fetcher
mock_fetcher.get_keys.assert_called_once()
def test_verify_json_falls_back_to_other_fetchers(self):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
}
}
)
def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
)
mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
mock_fetcher2 = keyring.KeyFetcher()
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
json1 = {}
signedjson.sign.sign_json(json1, "server1", key1)
results = kr.verify_json_objects_for_server(
[("server1", json1, 1200), ("server1", json1, 1500)]
)
self.assertEqual(len(results), 2)
self.get_success(results[0])
e = self.get_failure(results[1], SynapseError).value
self.assertEqual(e.errcode, "M_UNAUTHORIZED")
self.assertEqual(e.code, 401)
# there should have been a single call to each fetcher
mock_fetcher1.get_keys.assert_called_once()
mock_fetcher2.get_keys.assert_called_once()
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
@ -260,8 +338,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) keys = self.get_success(fetcher.get_keys(keys_to_fetch))
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key, testverifykey)
@ -288,9 +366,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
# change the server name: it should cause a rejection # change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER" response["server_name"] = "OTHER_SERVER"
self.get_failure( self.get_failure(fetcher.get_keys(keys_to_fetch), KeyLookupError)
fetcher.get_keys(server_name_and_key_ids), KeyLookupError
)
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
@ -342,8 +418,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys) self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
@ -401,7 +477,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def get_key_from_perspectives(response): def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs) fetcher = PerspectivesKeyFetcher(self.hs)
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] keys_to_fetch = {SERVER_NAME: {"key1": 0}}
def post_json(destination, path, data, **kwargs): def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(destination, self.mock_perspective_server.server_name)
@ -410,9 +486,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json self.http_client.post_json.side_effect = post_json
return self.get_success( return self.get_success(fetcher.get_keys(keys_to_fetch))
fetcher.get_keys(server_name_and_key_ids)
)
# start with a valid response so we can check we are testing the right thing # start with a valid response so we can check we are testing the right thing
response = build_response() response = build_response()
@ -435,6 +509,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
def get_key_id(key):
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
return "%s:%s" % (key.alg, key.version)
@defer.inlineCallbacks @defer.inlineCallbacks
def run_in_context(f, *args, **kwargs): def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx") as ctx: with LoggingContext("testctx") as ctx:
@ -445,14 +524,16 @@ def run_in_context(f, *args, **kwargs):
defer.returnValue(rv) defer.returnValue(rv)
def _verify_json_for_server(keyring, server_name, json_object): def _verify_json_for_server(keyring, server_name, json_object, validity_time):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped """thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks. with the patched defer.inlineCallbacks.
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def v(): def v():
rv1 = yield keyring.verify_json_for_server(server_name, json_object) rv1 = yield keyring.verify_json_for_server(
server_name, json_object, validity_time
)
defer.returnValue(rv1) defer.returnValue(rv1)
return run_in_context(v) return run_in_context(v)

View File

@ -408,7 +408,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
users_in_room = self.get_success(self.store.get_users_in_room(room_id)) users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([], users_in_room) self.assertEqual([], users_in_room)
@unittest.DEBUG
def test_shutdown_room_block_peek(self): def test_shutdown_room_block_peek(self):
"""Test that a world_readable room can no longer be peeked into after """Test that a world_readable room can no longer be peeked into after
it has been shut down. it has been shut down.

View File

@ -30,7 +30,7 @@ from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test" myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1" PATH_PREFIX = "/_matrix/client/r0"
class MockHandlerProfileTestCase(unittest.TestCase): class MockHandlerProfileTestCase(unittest.TestCase):