Fix additional type hints from Twisted 21.2.0. (#9591)
This commit is contained in:
parent
1e67bff833
commit
55da8df078
|
@ -0,0 +1 @@
|
|||
Fix incorrect type hints.
|
|
@ -164,7 +164,7 @@ class Auth:
|
|||
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
rights: str = "access",
|
||||
allow_expired: bool = False,
|
||||
|
|
|
@ -880,7 +880,9 @@ class FederationHandlerRegistry:
|
|||
self.edu_handlers = (
|
||||
{}
|
||||
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
|
||||
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
|
||||
self.query_handlers = (
|
||||
{}
|
||||
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
|
||||
|
||||
# Map from type to instance names that we should route EDU handling to.
|
||||
# We randomly choose one instance from the list to route to for each new
|
||||
|
@ -914,7 +916,7 @@ class FederationHandlerRegistry:
|
|||
self.edu_handlers[edu_type] = handler
|
||||
|
||||
def register_query_handler(
|
||||
self, query_type: str, handler: Callable[[dict], defer.Deferred]
|
||||
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||
):
|
||||
"""Sets the handler callable that will be used to handle an incoming
|
||||
federation query of the given type.
|
||||
|
@ -987,7 +989,7 @@ class FederationHandlerRegistry:
|
|||
# Oh well, let's just log and move on.
|
||||
logger.warning("No handler registered for EDU type %s", edu_type)
|
||||
|
||||
async def on_query(self, query_type: str, args: dict):
|
||||
async def on_query(self, query_type: str, args: dict) -> JsonDict:
|
||||
handler = self.query_handlers.get(query_type)
|
||||
if handler:
|
||||
return await handler(args)
|
||||
|
|
|
@ -34,6 +34,7 @@ from pymacaroons.exceptions import (
|
|||
from typing_extensions import TypedDict
|
||||
|
||||
from twisted.web.client import readBody
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.oidc_config import (
|
||||
|
@ -538,7 +539,7 @@ class OidcProvider:
|
|||
"""
|
||||
metadata = await self.load_metadata()
|
||||
token_endpoint = metadata.get("token_endpoint")
|
||||
headers = {
|
||||
raw_headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": self._http_client.user_agent,
|
||||
"Accept": "application/json",
|
||||
|
@ -552,10 +553,10 @@ class OidcProvider:
|
|||
body = urlencode(args, True)
|
||||
|
||||
# Fill the body/headers with credentials
|
||||
uri, headers, body = self._client_auth.prepare(
|
||||
method="POST", uri=token_endpoint, headers=headers, body=body
|
||||
uri, raw_headers, body = self._client_auth.prepare(
|
||||
method="POST", uri=token_endpoint, headers=raw_headers, body=body
|
||||
)
|
||||
headers = {k: [v] for (k, v) in headers.items()}
|
||||
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||
|
||||
# Do the actual request
|
||||
# We're not using the SimpleHttpClient util methods as we don't want to
|
||||
|
|
|
@ -57,7 +57,13 @@ from twisted.web.client import (
|
|||
)
|
||||
from twisted.web.http import PotentialDataLoss
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
|
||||
from twisted.web.iweb import (
|
||||
UNKNOWN_LENGTH,
|
||||
IAgent,
|
||||
IBodyProducer,
|
||||
IPolicyForHTTPS,
|
||||
IResponse,
|
||||
)
|
||||
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
|
||||
|
@ -870,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
|
|||
return query_str.encode("utf8")
|
||||
|
||||
|
||||
@implementer(IPolicyForHTTPS)
|
||||
class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
||||
"""
|
||||
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
|
||||
|
|
|
@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
|
|||
TCP4ClientEndpoint,
|
||||
TCP6ClientEndpoint,
|
||||
)
|
||||
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
|
||||
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
|
||||
from twisted.internet.protocol import Factory, Protocol
|
||||
from twisted.internet.tcp import Connection
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -52,7 +53,9 @@ class LogProducer:
|
|||
format: A callable to format the log record to a string.
|
||||
"""
|
||||
|
||||
transport = attr.ib(type=ITransport)
|
||||
# This is essentially ITCPTransport, but that is missing certain fields
|
||||
# (connected and registerProducer) which are part of the implementation.
|
||||
transport = attr.ib(type=Connection)
|
||||
_format = attr.ib(type=Callable[[logging.LogRecord], str])
|
||||
_buffer = attr.ib(type=deque)
|
||||
_paused = attr.ib(default=False, type=bool, init=False)
|
||||
|
@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
|
|||
if self._connection_waiter:
|
||||
return
|
||||
|
||||
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
||||
|
||||
def fail(failure: Failure) -> None:
|
||||
# If the Deferred was cancelled (e.g. during shutdown) do not try to
|
||||
# reconnect (this will cause an infinite loop of errors).
|
||||
|
@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
|
|||
self._connect()
|
||||
|
||||
def writer(result: Protocol) -> None:
|
||||
# Force recognising transport as a Connection and not the more
|
||||
# generic ITransport.
|
||||
transport = result.transport # type: Connection # type: ignore
|
||||
|
||||
# We have a connection. If we already have a producer, and its
|
||||
# transport is the same, just trigger a resumeProducing.
|
||||
if self._producer and result.transport is self._producer.transport:
|
||||
if self._producer and transport is self._producer.transport:
|
||||
self._producer.resumeProducing()
|
||||
self._connection_waiter = None
|
||||
return
|
||||
|
@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
|
|||
# Make a new producer and start it.
|
||||
self._producer = LogProducer(
|
||||
buffer=self._buffer,
|
||||
transport=result.transport,
|
||||
transport=transport,
|
||||
format=self.format,
|
||||
)
|
||||
result.transport.registerProducer(self._producer, True)
|
||||
transport.registerProducer(self._producer, True)
|
||||
self._producer.resumeProducing()
|
||||
self._connection_waiter = None
|
||||
|
||||
self._connection_waiter.addCallbacks(writer, fail)
|
||||
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
|
||||
deferred.addCallbacks(writer, fail)
|
||||
self._connection_waiter = deferred
|
||||
|
||||
def _handle_pressure(self) -> None:
|
||||
"""
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from twisted.internet.base import DelayedCall
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
from twisted.internet.interfaces import IDelayedCall
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import Pusher, PusherConfig, ThrottleParams
|
||||
|
@ -66,7 +66,7 @@ class EmailPusher(Pusher):
|
|||
|
||||
self.store = self.hs.get_datastore()
|
||||
self.email = pusher_config.pushkey
|
||||
self.timed_call = None # type: Optional[DelayedCall]
|
||||
self.timed_call = None # type: Optional[IDelayedCall]
|
||||
self.throttle_params = {} # type: Dict[str, ThrottleParams]
|
||||
self._inited = False
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
|
|||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import AbstractConnection
|
||||
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||
from synapse.replication.tcp.streams import (
|
||||
STREAMS_MAP,
|
||||
AccountDataStream,
|
||||
|
@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
|
|||
|
||||
# the type of the entries in _command_queues_by_stream
|
||||
_StreamCommandQueue = Deque[
|
||||
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
|
||||
Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
|
||||
]
|
||||
|
||||
|
||||
|
@ -174,7 +174,7 @@ class ReplicationCommandHandler:
|
|||
|
||||
# The currently connected connections. (The list of places we need to send
|
||||
# outgoing replication commands to.)
|
||||
self._connections = [] # type: List[AbstractConnection]
|
||||
self._connections = [] # type: List[IReplicationConnection]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_total_connections",
|
||||
|
@ -197,7 +197,7 @@ class ReplicationCommandHandler:
|
|||
|
||||
# For each connection, the incoming stream names that have received a POSITION
|
||||
# from that connection.
|
||||
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
|
||||
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_command_queue",
|
||||
|
@ -220,7 +220,7 @@ class ReplicationCommandHandler:
|
|||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
|
||||
def _add_command_to_stream_queue(
|
||||
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||
) -> None:
|
||||
"""Queue the given received command for processing
|
||||
|
||||
|
@ -267,7 +267,7 @@ class ReplicationCommandHandler:
|
|||
async def _process_command(
|
||||
self,
|
||||
cmd: Union[PositionCommand, RdataCommand],
|
||||
conn: AbstractConnection,
|
||||
conn: IReplicationConnection,
|
||||
stream_name: str,
|
||||
) -> None:
|
||||
if isinstance(cmd, PositionCommand):
|
||||
|
@ -321,10 +321,10 @@ class ReplicationCommandHandler:
|
|||
"""Get a list of streams that this instances replicates."""
|
||||
return self._streams_to_replicate
|
||||
|
||||
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
|
||||
self.send_positions_to_connection(conn)
|
||||
|
||||
def send_positions_to_connection(self, conn: AbstractConnection):
|
||||
def send_positions_to_connection(self, conn: IReplicationConnection):
|
||||
"""Send current position of all streams this process is source of to
|
||||
the connection.
|
||||
"""
|
||||
|
@ -347,7 +347,7 @@ class ReplicationCommandHandler:
|
|||
)
|
||||
|
||||
def on_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: UserSyncCommand
|
||||
self, conn: IReplicationConnection, cmd: UserSyncCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
user_sync_counter.inc()
|
||||
|
||||
|
@ -359,21 +359,23 @@ class ReplicationCommandHandler:
|
|||
return None
|
||||
|
||||
def on_CLEAR_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
|
||||
self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
if self._is_master:
|
||||
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
else:
|
||||
return None
|
||||
|
||||
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
|
||||
def on_FEDERATION_ACK(
|
||||
self, conn: IReplicationConnection, cmd: FederationAckCommand
|
||||
):
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self._federation_sender:
|
||||
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
||||
|
||||
def on_USER_IP(
|
||||
self, conn: AbstractConnection, cmd: UserIpCommand
|
||||
self, conn: IReplicationConnection, cmd: UserIpCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
|
@ -395,7 +397,7 @@ class ReplicationCommandHandler:
|
|||
assert self._server_notices_sender is not None
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore RDATA that are just our own echoes
|
||||
return
|
||||
|
@ -412,7 +414,7 @@ class ReplicationCommandHandler:
|
|||
self._add_command_to_stream_queue(conn, cmd)
|
||||
|
||||
async def _process_rdata(
|
||||
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
|
||||
self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
|
||||
) -> None:
|
||||
"""Process an RDATA command
|
||||
|
||||
|
@ -486,7 +488,7 @@ class ReplicationCommandHandler:
|
|||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore POSITION that are just our own echoes
|
||||
return
|
||||
|
@ -496,7 +498,7 @@ class ReplicationCommandHandler:
|
|||
self._add_command_to_stream_queue(conn, cmd)
|
||||
|
||||
async def _process_position(
|
||||
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
|
||||
self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
|
||||
) -> None:
|
||||
"""Process a POSITION command
|
||||
|
||||
|
@ -553,7 +555,9 @@ class ReplicationCommandHandler:
|
|||
|
||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||
|
||||
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
|
||||
def on_REMOTE_SERVER_UP(
|
||||
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
|
||||
):
|
||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||
|
||||
|
@ -576,7 +580,7 @@ class ReplicationCommandHandler:
|
|||
# between two instances, but that is not currently supported).
|
||||
self.send_command(cmd, ignore_conn=conn)
|
||||
|
||||
def new_connection(self, connection: AbstractConnection):
|
||||
def new_connection(self, connection: IReplicationConnection):
|
||||
"""Called when we have a new connection."""
|
||||
self._connections.append(connection)
|
||||
|
||||
|
@ -603,7 +607,7 @@ class ReplicationCommandHandler:
|
|||
UserSyncCommand(self._instance_id, user_id, True, now)
|
||||
)
|
||||
|
||||
def lost_connection(self, connection: AbstractConnection):
|
||||
def lost_connection(self, connection: IReplicationConnection):
|
||||
"""Called when a connection is closed/lost."""
|
||||
# we no longer need _streams_by_connection for this connection.
|
||||
streams = self._streams_by_connection.pop(connection, None)
|
||||
|
@ -624,7 +628,7 @@ class ReplicationCommandHandler:
|
|||
return bool(self._connections)
|
||||
|
||||
def send_command(
|
||||
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
|
||||
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
|
||||
):
|
||||
"""Send a command to all connected connections.
|
||||
|
||||
|
|
|
@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
|
|||
> ERROR server stopping
|
||||
* connection closed by server *
|
||||
"""
|
||||
import abc
|
||||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
|
@ -54,6 +53,7 @@ from inspect import isawaitable
|
|||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from prometheus_client import Counter
|
||||
from zope.interface import Interface, implementer
|
||||
|
||||
from twisted.internet import task
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
|
@ -121,6 +121,14 @@ class ConnectionStates:
|
|||
CLOSED = "closed"
|
||||
|
||||
|
||||
class IReplicationConnection(Interface):
|
||||
"""An interface for replication connections."""
|
||||
|
||||
def send_command(cmd: Command):
|
||||
"""Send the command down the connection"""
|
||||
|
||||
|
||||
@implementer(IReplicationConnection)
|
||||
class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
"""Base replication protocol shared between client and server.
|
||||
|
||||
|
@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||
self.send_command(ReplicateCommand())
|
||||
|
||||
|
||||
class AbstractConnection(abc.ABC):
|
||||
"""An interface for replication connections."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def send_command(self, cmd: Command):
|
||||
"""Send the command down the connection"""
|
||||
pass
|
||||
|
||||
|
||||
# This tells python that `BaseReplicationStreamProtocol` implements the
|
||||
# interface.
|
||||
AbstractConnection.register(BaseReplicationStreamProtocol)
|
||||
|
||||
|
||||
# The following simply registers metrics for the replication connections
|
||||
|
||||
pending_commands = LaterGauge(
|
||||
|
|
|
@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
|
|||
|
||||
import attr
|
||||
import txredisapi
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.interfaces import IAddress, IConnector
|
||||
|
@ -36,7 +37,7 @@ from synapse.replication.tcp.commands import (
|
|||
parse_command_from_line,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import (
|
||||
AbstractConnection,
|
||||
IReplicationConnection,
|
||||
tcp_inbound_commands_counter,
|
||||
tcp_outbound_commands_counter,
|
||||
)
|
||||
|
@ -66,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
|
|||
pass
|
||||
|
||||
|
||||
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
@implementer(IReplicationConnection)
|
||||
class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
"""Connection to redis subscribed to replication stream.
|
||||
|
||||
This class fulfils two functions:
|
||||
|
@ -75,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
connection, parsing *incoming* messages into replication commands, and passing them
|
||||
to `ReplicationCommandHandler`
|
||||
|
||||
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
|
||||
(b) it implements the IReplicationConnection API, where it sends *outgoing* commands
|
||||
onto outbound_redis_connection.
|
||||
|
||||
Due to the vagaries of `txredisapi` we don't want to have a custom
|
||||
|
|
|
@ -15,10 +15,9 @@
|
|||
|
||||
import re
|
||||
|
||||
import twisted.web.server
|
||||
|
||||
import synapse.api.auth
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import UserID
|
||||
|
||||
|
||||
|
@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
|
|||
return patterns
|
||||
|
||||
|
||||
async def assert_requester_is_admin(
|
||||
auth: synapse.api.auth.Auth, request: twisted.web.server.Request
|
||||
) -> None:
|
||||
async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
|
||||
"""Verify that the requester is an admin user
|
||||
|
||||
Args:
|
||||
auth: api.auth.Auth singleton
|
||||
auth: Auth singleton
|
||||
request: incoming request
|
||||
|
||||
Raises:
|
||||
|
@ -53,11 +50,11 @@ async def assert_requester_is_admin(
|
|||
await assert_user_is_admin(auth, requester.user)
|
||||
|
||||
|
||||
async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
|
||||
async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
|
||||
"""Verify that the given user is an admin user
|
||||
|
||||
Args:
|
||||
auth: api.auth.Auth singleton
|
||||
auth: Auth singleton
|
||||
user_id: user to check
|
||||
|
||||
Raises:
|
||||
|
|
|
@ -17,10 +17,9 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import (
|
||||
admin_patterns,
|
||||
assert_requester_is_admin,
|
||||
|
@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
|
@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
|
@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(
|
||||
self, request: Request, server_name: str, media_id: str
|
||||
self, request: SynapseRequest, server_name: str, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
|
@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
if not is_admin:
|
||||
|
@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
|
|||
self.media_repository = hs.get_media_repository()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
before_ts = parse_integer(request, "before_ts", required=True)
|
||||
|
@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
|
|||
self.media_repository = hs.get_media_repository()
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: Request, server_name: str, media_id: str
|
||||
self, request: SynapseRequest, server_name: str, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
|
@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
|
|||
self.server_name = hs.hostname
|
||||
self.media_repository = hs.get_media_repository()
|
||||
|
||||
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, server_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
before_ts = parse_integer(request, "before_ts", required=True)
|
||||
|
|
|
@ -32,6 +32,7 @@ from synapse.http.servlet import (
|
|||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import GroupID, JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
|
|||
return 200, group_description
|
||||
|
||||
@_validate_group_id
|
||||
async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, category_id: Optional[str], room_id: str
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
group_id: str,
|
||||
category_id: Optional[str],
|
||||
room_id: str,
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, category_id: str, room_id: str
|
||||
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_GET(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
self, request: SynapseRequest, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
self, request: SynapseRequest, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
self, request: SynapseRequest, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_GET(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
self, request: SynapseRequest, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
self, request: SynapseRequest, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
self, request: SynapseRequest, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, role_id: Optional[str], user_id: str
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
group_id: str,
|
||||
role_id: Optional[str],
|
||||
user_id: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, role_id: str, user_id: str
|
||||
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
self.server_name = hs.hostname
|
||||
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, room_id: str
|
||||
self, request: SynapseRequest, group_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, room_id: str
|
||||
self, request: SynapseRequest, group_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
|
|||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, room_id: str, config_key: str
|
||||
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
|
|||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id, user_id
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id, user_id
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
|
|||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
|
||||
|
@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
|
|||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
|
|||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
|
|||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
|
|||
self.auth = hs.get_auth()
|
||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
await self.auth.get_user_by_req(request)
|
||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ from synapse.http.server import (
|
|||
respond_with_json_bytes,
|
||||
)
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.rest.media.v1._base import get_filename_from_headers
|
||||
|
@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||
request.setHeader(b"Allow", b"OPTIONS, GET")
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
|
||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
|
|
@ -22,6 +22,7 @@ from twisted.web.server import Request
|
|||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.media.v1.media_storage import SpamMediaException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
|
|||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
async def _async_render_POST(self, request: Request) -> None:
|
||||
async def _async_render_POST(self, request: SynapseRequest) -> None:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
|
|
|
@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||
|
||||
@cache_in_self
|
||||
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
|
||||
return (
|
||||
InsecureInterceptableContextFactory()
|
||||
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||
else RegularPolicyForHTTPS()
|
||||
)
|
||||
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
|
||||
return InsecureInterceptableContextFactory()
|
||||
return RegularPolicyForHTTPS()
|
||||
|
||||
@cache_in_self
|
||||
def get_simple_http_client(self) -> SimpleHttpClient:
|
||||
|
|
|
@ -17,7 +17,7 @@ import mock
|
|||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.replication.tcp.commands import FederationAckCommand
|
||||
from synapse.replication.tcp.protocol import AbstractConnection
|
||||
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||
from synapse.replication.tcp.streams.federation import FederationStream
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
|
|||
"""
|
||||
rch = self.hs.get_tcp_replication()
|
||||
|
||||
# wire up the ReplicationCommandHandler to a mock connection
|
||||
mock_connection = mock.Mock(spec=AbstractConnection)
|
||||
# wire up the ReplicationCommandHandler to a mock connection, which needs
|
||||
# to implement IReplicationConnection. (Note that Mock doesn't understand
|
||||
# interfaces, but casing an interface to a list gives the attributes.)
|
||||
mock_connection = mock.Mock(spec=list(IReplicationConnection))
|
||||
rch.new_connection(mock_connection)
|
||||
|
||||
# tell it it received an RDATA row
|
||||
|
|
Loading…
Reference in New Issue