Fix additional type hints from Twisted 21.2.0. (#9591)
This commit is contained in:
parent
1e67bff833
commit
55da8df078
|
@ -0,0 +1 @@
|
||||||
|
Fix incorrect type hints.
|
|
@ -164,7 +164,7 @@ class Auth:
|
||||||
|
|
||||||
async def get_user_by_req(
|
async def get_user_by_req(
|
||||||
self,
|
self,
|
||||||
request: Request,
|
request: SynapseRequest,
|
||||||
allow_guest: bool = False,
|
allow_guest: bool = False,
|
||||||
rights: str = "access",
|
rights: str = "access",
|
||||||
allow_expired: bool = False,
|
allow_expired: bool = False,
|
||||||
|
|
|
@ -880,7 +880,9 @@ class FederationHandlerRegistry:
|
||||||
self.edu_handlers = (
|
self.edu_handlers = (
|
||||||
{}
|
{}
|
||||||
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
|
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
|
||||||
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
|
self.query_handlers = (
|
||||||
|
{}
|
||||||
|
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
|
||||||
|
|
||||||
# Map from type to instance names that we should route EDU handling to.
|
# Map from type to instance names that we should route EDU handling to.
|
||||||
# We randomly choose one instance from the list to route to for each new
|
# We randomly choose one instance from the list to route to for each new
|
||||||
|
@ -914,7 +916,7 @@ class FederationHandlerRegistry:
|
||||||
self.edu_handlers[edu_type] = handler
|
self.edu_handlers[edu_type] = handler
|
||||||
|
|
||||||
def register_query_handler(
|
def register_query_handler(
|
||||||
self, query_type: str, handler: Callable[[dict], defer.Deferred]
|
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||||
):
|
):
|
||||||
"""Sets the handler callable that will be used to handle an incoming
|
"""Sets the handler callable that will be used to handle an incoming
|
||||||
federation query of the given type.
|
federation query of the given type.
|
||||||
|
@ -987,7 +989,7 @@ class FederationHandlerRegistry:
|
||||||
# Oh well, let's just log and move on.
|
# Oh well, let's just log and move on.
|
||||||
logger.warning("No handler registered for EDU type %s", edu_type)
|
logger.warning("No handler registered for EDU type %s", edu_type)
|
||||||
|
|
||||||
async def on_query(self, query_type: str, args: dict):
|
async def on_query(self, query_type: str, args: dict) -> JsonDict:
|
||||||
handler = self.query_handlers.get(query_type)
|
handler = self.query_handlers.get(query_type)
|
||||||
if handler:
|
if handler:
|
||||||
return await handler(args)
|
return await handler(args)
|
||||||
|
|
|
@ -34,6 +34,7 @@ from pymacaroons.exceptions import (
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from twisted.web.client import readBody
|
from twisted.web.client import readBody
|
||||||
|
from twisted.web.http_headers import Headers
|
||||||
|
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.config.oidc_config import (
|
from synapse.config.oidc_config import (
|
||||||
|
@ -538,7 +539,7 @@ class OidcProvider:
|
||||||
"""
|
"""
|
||||||
metadata = await self.load_metadata()
|
metadata = await self.load_metadata()
|
||||||
token_endpoint = metadata.get("token_endpoint")
|
token_endpoint = metadata.get("token_endpoint")
|
||||||
headers = {
|
raw_headers = {
|
||||||
"Content-Type": "application/x-www-form-urlencoded",
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
"User-Agent": self._http_client.user_agent,
|
"User-Agent": self._http_client.user_agent,
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
|
@ -552,10 +553,10 @@ class OidcProvider:
|
||||||
body = urlencode(args, True)
|
body = urlencode(args, True)
|
||||||
|
|
||||||
# Fill the body/headers with credentials
|
# Fill the body/headers with credentials
|
||||||
uri, headers, body = self._client_auth.prepare(
|
uri, raw_headers, body = self._client_auth.prepare(
|
||||||
method="POST", uri=token_endpoint, headers=headers, body=body
|
method="POST", uri=token_endpoint, headers=raw_headers, body=body
|
||||||
)
|
)
|
||||||
headers = {k: [v] for (k, v) in headers.items()}
|
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||||
|
|
||||||
# Do the actual request
|
# Do the actual request
|
||||||
# We're not using the SimpleHttpClient util methods as we don't want to
|
# We're not using the SimpleHttpClient util methods as we don't want to
|
||||||
|
|
|
@ -57,7 +57,13 @@ from twisted.web.client import (
|
||||||
)
|
)
|
||||||
from twisted.web.http import PotentialDataLoss
|
from twisted.web.http import PotentialDataLoss
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
|
from twisted.web.iweb import (
|
||||||
|
UNKNOWN_LENGTH,
|
||||||
|
IAgent,
|
||||||
|
IBodyProducer,
|
||||||
|
IPolicyForHTTPS,
|
||||||
|
IResponse,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
|
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
|
||||||
|
@ -870,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
|
||||||
return query_str.encode("utf8")
|
return query_str.encode("utf8")
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IPolicyForHTTPS)
|
||||||
class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
||||||
"""
|
"""
|
||||||
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
|
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
|
||||||
|
|
|
@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
|
||||||
TCP4ClientEndpoint,
|
TCP4ClientEndpoint,
|
||||||
TCP6ClientEndpoint,
|
TCP6ClientEndpoint,
|
||||||
)
|
)
|
||||||
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
|
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
|
||||||
from twisted.internet.protocol import Factory, Protocol
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
|
from twisted.internet.tcp import Connection
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -52,7 +53,9 @@ class LogProducer:
|
||||||
format: A callable to format the log record to a string.
|
format: A callable to format the log record to a string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
transport = attr.ib(type=ITransport)
|
# This is essentially ITCPTransport, but that is missing certain fields
|
||||||
|
# (connected and registerProducer) which are part of the implementation.
|
||||||
|
transport = attr.ib(type=Connection)
|
||||||
_format = attr.ib(type=Callable[[logging.LogRecord], str])
|
_format = attr.ib(type=Callable[[logging.LogRecord], str])
|
||||||
_buffer = attr.ib(type=deque)
|
_buffer = attr.ib(type=deque)
|
||||||
_paused = attr.ib(default=False, type=bool, init=False)
|
_paused = attr.ib(default=False, type=bool, init=False)
|
||||||
|
@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
|
||||||
if self._connection_waiter:
|
if self._connection_waiter:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
|
||||||
|
|
||||||
def fail(failure: Failure) -> None:
|
def fail(failure: Failure) -> None:
|
||||||
# If the Deferred was cancelled (e.g. during shutdown) do not try to
|
# If the Deferred was cancelled (e.g. during shutdown) do not try to
|
||||||
# reconnect (this will cause an infinite loop of errors).
|
# reconnect (this will cause an infinite loop of errors).
|
||||||
|
@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|
||||||
def writer(result: Protocol) -> None:
|
def writer(result: Protocol) -> None:
|
||||||
|
# Force recognising transport as a Connection and not the more
|
||||||
|
# generic ITransport.
|
||||||
|
transport = result.transport # type: Connection # type: ignore
|
||||||
|
|
||||||
# We have a connection. If we already have a producer, and its
|
# We have a connection. If we already have a producer, and its
|
||||||
# transport is the same, just trigger a resumeProducing.
|
# transport is the same, just trigger a resumeProducing.
|
||||||
if self._producer and result.transport is self._producer.transport:
|
if self._producer and transport is self._producer.transport:
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
self._connection_waiter = None
|
self._connection_waiter = None
|
||||||
return
|
return
|
||||||
|
@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
|
||||||
# Make a new producer and start it.
|
# Make a new producer and start it.
|
||||||
self._producer = LogProducer(
|
self._producer = LogProducer(
|
||||||
buffer=self._buffer,
|
buffer=self._buffer,
|
||||||
transport=result.transport,
|
transport=transport,
|
||||||
format=self.format,
|
format=self.format,
|
||||||
)
|
)
|
||||||
result.transport.registerProducer(self._producer, True)
|
transport.registerProducer(self._producer, True)
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
self._connection_waiter = None
|
self._connection_waiter = None
|
||||||
|
|
||||||
self._connection_waiter.addCallbacks(writer, fail)
|
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
|
||||||
|
deferred.addCallbacks(writer, fail)
|
||||||
|
self._connection_waiter = deferred
|
||||||
|
|
||||||
def _handle_pressure(self) -> None:
|
def _handle_pressure(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from twisted.internet.base import DelayedCall
|
|
||||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||||
|
from twisted.internet.interfaces import IDelayedCall
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.push import Pusher, PusherConfig, ThrottleParams
|
from synapse.push import Pusher, PusherConfig, ThrottleParams
|
||||||
|
@ -66,7 +66,7 @@ class EmailPusher(Pusher):
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.email = pusher_config.pushkey
|
self.email = pusher_config.pushkey
|
||||||
self.timed_call = None # type: Optional[DelayedCall]
|
self.timed_call = None # type: Optional[IDelayedCall]
|
||||||
self.throttle_params = {} # type: Dict[str, ThrottleParams]
|
self.throttle_params = {} # type: Dict[str, ThrottleParams]
|
||||||
self._inited = False
|
self._inited = False
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
|
||||||
UserIpCommand,
|
UserIpCommand,
|
||||||
UserSyncCommand,
|
UserSyncCommand,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.protocol import AbstractConnection
|
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||||
from synapse.replication.tcp.streams import (
|
from synapse.replication.tcp.streams import (
|
||||||
STREAMS_MAP,
|
STREAMS_MAP,
|
||||||
AccountDataStream,
|
AccountDataStream,
|
||||||
|
@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
|
||||||
|
|
||||||
# the type of the entries in _command_queues_by_stream
|
# the type of the entries in _command_queues_by_stream
|
||||||
_StreamCommandQueue = Deque[
|
_StreamCommandQueue = Deque[
|
||||||
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
|
Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,7 +174,7 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
# The currently connected connections. (The list of places we need to send
|
# The currently connected connections. (The list of places we need to send
|
||||||
# outgoing replication commands to.)
|
# outgoing replication commands to.)
|
||||||
self._connections = [] # type: List[AbstractConnection]
|
self._connections = [] # type: List[IReplicationConnection]
|
||||||
|
|
||||||
LaterGauge(
|
LaterGauge(
|
||||||
"synapse_replication_tcp_resource_total_connections",
|
"synapse_replication_tcp_resource_total_connections",
|
||||||
|
@ -197,7 +197,7 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
# For each connection, the incoming stream names that have received a POSITION
|
# For each connection, the incoming stream names that have received a POSITION
|
||||||
# from that connection.
|
# from that connection.
|
||||||
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
|
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
|
||||||
|
|
||||||
LaterGauge(
|
LaterGauge(
|
||||||
"synapse_replication_tcp_command_queue",
|
"synapse_replication_tcp_command_queue",
|
||||||
|
@ -220,7 +220,7 @@ class ReplicationCommandHandler:
|
||||||
self._server_notices_sender = hs.get_server_notices_sender()
|
self._server_notices_sender = hs.get_server_notices_sender()
|
||||||
|
|
||||||
def _add_command_to_stream_queue(
|
def _add_command_to_stream_queue(
|
||||||
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
|
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Queue the given received command for processing
|
"""Queue the given received command for processing
|
||||||
|
|
||||||
|
@ -267,7 +267,7 @@ class ReplicationCommandHandler:
|
||||||
async def _process_command(
|
async def _process_command(
|
||||||
self,
|
self,
|
||||||
cmd: Union[PositionCommand, RdataCommand],
|
cmd: Union[PositionCommand, RdataCommand],
|
||||||
conn: AbstractConnection,
|
conn: IReplicationConnection,
|
||||||
stream_name: str,
|
stream_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(cmd, PositionCommand):
|
if isinstance(cmd, PositionCommand):
|
||||||
|
@ -321,10 +321,10 @@ class ReplicationCommandHandler:
|
||||||
"""Get a list of streams that this instances replicates."""
|
"""Get a list of streams that this instances replicates."""
|
||||||
return self._streams_to_replicate
|
return self._streams_to_replicate
|
||||||
|
|
||||||
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
|
||||||
self.send_positions_to_connection(conn)
|
self.send_positions_to_connection(conn)
|
||||||
|
|
||||||
def send_positions_to_connection(self, conn: AbstractConnection):
|
def send_positions_to_connection(self, conn: IReplicationConnection):
|
||||||
"""Send current position of all streams this process is source of to
|
"""Send current position of all streams this process is source of to
|
||||||
the connection.
|
the connection.
|
||||||
"""
|
"""
|
||||||
|
@ -347,7 +347,7 @@ class ReplicationCommandHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_USER_SYNC(
|
def on_USER_SYNC(
|
||||||
self, conn: AbstractConnection, cmd: UserSyncCommand
|
self, conn: IReplicationConnection, cmd: UserSyncCommand
|
||||||
) -> Optional[Awaitable[None]]:
|
) -> Optional[Awaitable[None]]:
|
||||||
user_sync_counter.inc()
|
user_sync_counter.inc()
|
||||||
|
|
||||||
|
@ -359,21 +359,23 @@ class ReplicationCommandHandler:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_CLEAR_USER_SYNC(
|
def on_CLEAR_USER_SYNC(
|
||||||
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
|
self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
|
||||||
) -> Optional[Awaitable[None]]:
|
) -> Optional[Awaitable[None]]:
|
||||||
if self._is_master:
|
if self._is_master:
|
||||||
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
|
def on_FEDERATION_ACK(
|
||||||
|
self, conn: IReplicationConnection, cmd: FederationAckCommand
|
||||||
|
):
|
||||||
federation_ack_counter.inc()
|
federation_ack_counter.inc()
|
||||||
|
|
||||||
if self._federation_sender:
|
if self._federation_sender:
|
||||||
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
||||||
|
|
||||||
def on_USER_IP(
|
def on_USER_IP(
|
||||||
self, conn: AbstractConnection, cmd: UserIpCommand
|
self, conn: IReplicationConnection, cmd: UserIpCommand
|
||||||
) -> Optional[Awaitable[None]]:
|
) -> Optional[Awaitable[None]]:
|
||||||
user_ip_cache_counter.inc()
|
user_ip_cache_counter.inc()
|
||||||
|
|
||||||
|
@ -395,7 +397,7 @@ class ReplicationCommandHandler:
|
||||||
assert self._server_notices_sender is not None
|
assert self._server_notices_sender is not None
|
||||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||||
|
|
||||||
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
|
||||||
if cmd.instance_name == self._instance_name:
|
if cmd.instance_name == self._instance_name:
|
||||||
# Ignore RDATA that are just our own echoes
|
# Ignore RDATA that are just our own echoes
|
||||||
return
|
return
|
||||||
|
@ -412,7 +414,7 @@ class ReplicationCommandHandler:
|
||||||
self._add_command_to_stream_queue(conn, cmd)
|
self._add_command_to_stream_queue(conn, cmd)
|
||||||
|
|
||||||
async def _process_rdata(
|
async def _process_rdata(
|
||||||
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
|
self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process an RDATA command
|
"""Process an RDATA command
|
||||||
|
|
||||||
|
@ -486,7 +488,7 @@ class ReplicationCommandHandler:
|
||||||
stream_name, instance_name, token, rows
|
stream_name, instance_name, token, rows
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
|
||||||
if cmd.instance_name == self._instance_name:
|
if cmd.instance_name == self._instance_name:
|
||||||
# Ignore POSITION that are just our own echoes
|
# Ignore POSITION that are just our own echoes
|
||||||
return
|
return
|
||||||
|
@ -496,7 +498,7 @@ class ReplicationCommandHandler:
|
||||||
self._add_command_to_stream_queue(conn, cmd)
|
self._add_command_to_stream_queue(conn, cmd)
|
||||||
|
|
||||||
async def _process_position(
|
async def _process_position(
|
||||||
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
|
self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a POSITION command
|
"""Process a POSITION command
|
||||||
|
|
||||||
|
@ -553,7 +555,9 @@ class ReplicationCommandHandler:
|
||||||
|
|
||||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||||
|
|
||||||
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
|
def on_REMOTE_SERVER_UP(
|
||||||
|
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
|
||||||
|
):
|
||||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||||
|
|
||||||
|
@ -576,7 +580,7 @@ class ReplicationCommandHandler:
|
||||||
# between two instances, but that is not currently supported).
|
# between two instances, but that is not currently supported).
|
||||||
self.send_command(cmd, ignore_conn=conn)
|
self.send_command(cmd, ignore_conn=conn)
|
||||||
|
|
||||||
def new_connection(self, connection: AbstractConnection):
|
def new_connection(self, connection: IReplicationConnection):
|
||||||
"""Called when we have a new connection."""
|
"""Called when we have a new connection."""
|
||||||
self._connections.append(connection)
|
self._connections.append(connection)
|
||||||
|
|
||||||
|
@ -603,7 +607,7 @@ class ReplicationCommandHandler:
|
||||||
UserSyncCommand(self._instance_id, user_id, True, now)
|
UserSyncCommand(self._instance_id, user_id, True, now)
|
||||||
)
|
)
|
||||||
|
|
||||||
def lost_connection(self, connection: AbstractConnection):
|
def lost_connection(self, connection: IReplicationConnection):
|
||||||
"""Called when a connection is closed/lost."""
|
"""Called when a connection is closed/lost."""
|
||||||
# we no longer need _streams_by_connection for this connection.
|
# we no longer need _streams_by_connection for this connection.
|
||||||
streams = self._streams_by_connection.pop(connection, None)
|
streams = self._streams_by_connection.pop(connection, None)
|
||||||
|
@ -624,7 +628,7 @@ class ReplicationCommandHandler:
|
||||||
return bool(self._connections)
|
return bool(self._connections)
|
||||||
|
|
||||||
def send_command(
|
def send_command(
|
||||||
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
|
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
|
||||||
):
|
):
|
||||||
"""Send a command to all connected connections.
|
"""Send a command to all connected connections.
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
|
||||||
> ERROR server stopping
|
> ERROR server stopping
|
||||||
* connection closed by server *
|
* connection closed by server *
|
||||||
"""
|
"""
|
||||||
import abc
|
|
||||||
import fcntl
|
import fcntl
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
|
@ -54,6 +53,7 @@ from inspect import isawaitable
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
from zope.interface import Interface, implementer
|
||||||
|
|
||||||
from twisted.internet import task
|
from twisted.internet import task
|
||||||
from twisted.protocols.basic import LineOnlyReceiver
|
from twisted.protocols.basic import LineOnlyReceiver
|
||||||
|
@ -121,6 +121,14 @@ class ConnectionStates:
|
||||||
CLOSED = "closed"
|
CLOSED = "closed"
|
||||||
|
|
||||||
|
|
||||||
|
class IReplicationConnection(Interface):
|
||||||
|
"""An interface for replication connections."""
|
||||||
|
|
||||||
|
def send_command(cmd: Command):
|
||||||
|
"""Send the command down the connection"""
|
||||||
|
|
||||||
|
|
||||||
|
@implementer(IReplicationConnection)
|
||||||
class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||||
"""Base replication protocol shared between client and server.
|
"""Base replication protocol shared between client and server.
|
||||||
|
|
||||||
|
@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||||
self.send_command(ReplicateCommand())
|
self.send_command(ReplicateCommand())
|
||||||
|
|
||||||
|
|
||||||
class AbstractConnection(abc.ABC):
|
|
||||||
"""An interface for replication connections."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def send_command(self, cmd: Command):
|
|
||||||
"""Send the command down the connection"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# This tells python that `BaseReplicationStreamProtocol` implements the
|
|
||||||
# interface.
|
|
||||||
AbstractConnection.register(BaseReplicationStreamProtocol)
|
|
||||||
|
|
||||||
|
|
||||||
# The following simply registers metrics for the replication connections
|
# The following simply registers metrics for the replication connections
|
||||||
|
|
||||||
pending_commands = LaterGauge(
|
pending_commands = LaterGauge(
|
||||||
|
|
|
@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import txredisapi
|
import txredisapi
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.internet.address import IPv4Address, IPv6Address
|
from twisted.internet.address import IPv4Address, IPv6Address
|
||||||
from twisted.internet.interfaces import IAddress, IConnector
|
from twisted.internet.interfaces import IAddress, IConnector
|
||||||
|
@ -36,7 +37,7 @@ from synapse.replication.tcp.commands import (
|
||||||
parse_command_from_line,
|
parse_command_from_line,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.protocol import (
|
from synapse.replication.tcp.protocol import (
|
||||||
AbstractConnection,
|
IReplicationConnection,
|
||||||
tcp_inbound_commands_counter,
|
tcp_inbound_commands_counter,
|
||||||
tcp_outbound_commands_counter,
|
tcp_outbound_commands_counter,
|
||||||
)
|
)
|
||||||
|
@ -66,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
@implementer(IReplicationConnection)
|
||||||
|
class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||||
"""Connection to redis subscribed to replication stream.
|
"""Connection to redis subscribed to replication stream.
|
||||||
|
|
||||||
This class fulfils two functions:
|
This class fulfils two functions:
|
||||||
|
@ -75,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||||
connection, parsing *incoming* messages into replication commands, and passing them
|
connection, parsing *incoming* messages into replication commands, and passing them
|
||||||
to `ReplicationCommandHandler`
|
to `ReplicationCommandHandler`
|
||||||
|
|
||||||
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
|
(b) it implements the IReplicationConnection API, where it sends *outgoing* commands
|
||||||
onto outbound_redis_connection.
|
onto outbound_redis_connection.
|
||||||
|
|
||||||
Due to the vagaries of `txredisapi` we don't want to have a custom
|
Due to the vagaries of `txredisapi` we don't want to have a custom
|
||||||
|
|
|
@ -15,10 +15,9 @@
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import twisted.web.server
|
from synapse.api.auth import Auth
|
||||||
|
|
||||||
import synapse.api.auth
|
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
|
||||||
return patterns
|
return patterns
|
||||||
|
|
||||||
|
|
||||||
async def assert_requester_is_admin(
|
async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
|
||||||
auth: synapse.api.auth.Auth, request: twisted.web.server.Request
|
|
||||||
) -> None:
|
|
||||||
"""Verify that the requester is an admin user
|
"""Verify that the requester is an admin user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auth: api.auth.Auth singleton
|
auth: Auth singleton
|
||||||
request: incoming request
|
request: incoming request
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -53,11 +50,11 @@ async def assert_requester_is_admin(
|
||||||
await assert_user_is_admin(auth, requester.user)
|
await assert_user_is_admin(auth, requester.user)
|
||||||
|
|
||||||
|
|
||||||
async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
|
async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
|
||||||
"""Verify that the given user is an admin user
|
"""Verify that the given user is an admin user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auth: api.auth.Auth singleton
|
auth: Auth singleton
|
||||||
user_id: user to check
|
user_id: user to check
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
|
|
@ -17,10 +17,9 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Tuple
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
from twisted.web.server import Request
|
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.admin._base import (
|
from synapse.rest.admin._base import (
|
||||||
admin_patterns,
|
admin_patterns,
|
||||||
assert_requester_is_admin,
|
assert_requester_is_admin,
|
||||||
|
@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(
|
async def on_POST(
|
||||||
self, request: Request, server_name: str, media_id: str
|
self, request: SynapseRequest, server_name: str, media_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, media_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
await assert_user_is_admin(self.auth, requester.user)
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
is_admin = await self.auth.is_server_admin(requester.user)
|
is_admin = await self.auth.is_server_admin(requester.user)
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
|
@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
|
||||||
self.media_repository = hs.get_media_repository()
|
self.media_repository = hs.get_media_repository()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
before_ts = parse_integer(request, "before_ts", required=True)
|
before_ts = parse_integer(request, "before_ts", required=True)
|
||||||
|
@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
|
||||||
self.media_repository = hs.get_media_repository()
|
self.media_repository = hs.get_media_repository()
|
||||||
|
|
||||||
async def on_DELETE(
|
async def on_DELETE(
|
||||||
self, request: Request, server_name: str, media_id: str
|
self, request: SynapseRequest, server_name: str, media_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
|
@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.media_repository = hs.get_media_repository()
|
self.media_repository = hs.get_media_repository()
|
||||||
|
|
||||||
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, server_name: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
before_ts = parse_integer(request, "before_ts", required=True)
|
before_ts = parse_integer(request, "before_ts", required=True)
|
||||||
|
|
|
@ -32,6 +32,7 @@ from synapse.http.servlet import (
|
||||||
assert_params_in_dict,
|
assert_params_in_dict,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
)
|
)
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.types import GroupID, JsonDict
|
from synapse.types import GroupID, JsonDict
|
||||||
|
|
||||||
from ._base import client_patterns
|
from ._base import client_patterns
|
||||||
|
@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
|
||||||
return 200, group_description
|
return 200, group_description
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(
|
async def on_PUT(
|
||||||
self, request: Request, group_id: str, category_id: Optional[str], room_id: str
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
group_id: str,
|
||||||
|
category_id: Optional[str],
|
||||||
|
room_id: str,
|
||||||
):
|
):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_DELETE(
|
async def on_DELETE(
|
||||||
self, request: Request, group_id: str, category_id: str, room_id: str
|
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
|
||||||
):
|
):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self, request: Request, group_id: str, category_id: str
|
self, request: SynapseRequest, group_id: str, category_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(
|
async def on_PUT(
|
||||||
self, request: Request, group_id: str, category_id: str
|
self, request: SynapseRequest, group_id: str, category_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_DELETE(
|
async def on_DELETE(
|
||||||
self, request: Request, group_id: str, category_id: str
|
self, request: SynapseRequest, group_id: str, category_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(
|
async def on_GET(
|
||||||
self, request: Request, group_id: str, role_id: str
|
self, request: SynapseRequest, group_id: str, role_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(
|
async def on_PUT(
|
||||||
self, request: Request, group_id: str, role_id: str
|
self, request: SynapseRequest, group_id: str, role_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_DELETE(
|
async def on_DELETE(
|
||||||
self, request: Request, group_id: str, role_id: str
|
self, request: SynapseRequest, group_id: str, role_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(
|
async def on_PUT(
|
||||||
self, request: Request, group_id: str, role_id: Optional[str], user_id: str
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
group_id: str,
|
||||||
|
role_id: Optional[str],
|
||||||
|
user_id: str,
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_DELETE(
|
async def on_DELETE(
|
||||||
self, request: Request, group_id: str, role_id: str, user_id: str
|
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
|
||||||
):
|
):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(
|
async def on_PUT(
|
||||||
self, request: Request, group_id: str, room_id: str
|
self, request: SynapseRequest, group_id: str, room_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_DELETE(
|
async def on_DELETE(
|
||||||
self, request: Request, group_id: str, room_id: str
|
self, request: SynapseRequest, group_id: str, room_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(
|
async def on_PUT(
|
||||||
self, request: Request, group_id: str, room_id: str, config_key: str
|
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
|
||||||
):
|
):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, group_id, user_id
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, group_id, user_id
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@_validate_group_id
|
@_validate_group_id
|
||||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, group_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
|
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
|
||||||
|
@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.groups_handler = hs.get_groups_local_handler()
|
self.groups_handler = hs.get_groups_local_handler()
|
||||||
|
|
||||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.app.homeserver import HomeServer
|
from synapse.app.homeserver import HomeServer
|
||||||
|
@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||||
|
|
||||||
async def _async_render_GET(self, request: Request) -> None:
|
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||||
await self.auth.get_user_by_req(request)
|
await self.auth.get_user_by_req(request)
|
||||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ from synapse.http.server import (
|
||||||
respond_with_json_bytes,
|
respond_with_json_bytes,
|
||||||
)
|
)
|
||||||
from synapse.http.servlet import parse_integer, parse_string
|
from synapse.http.servlet import parse_integer, parse_string
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.rest.media.v1._base import get_filename_from_headers
|
from synapse.rest.media.v1._base import get_filename_from_headers
|
||||||
|
@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
request.setHeader(b"Allow", b"OPTIONS, GET")
|
request.setHeader(b"Allow", b"OPTIONS, GET")
|
||||||
respond_with_json(request, 200, {}, send_cors=True)
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
|
|
||||||
async def _async_render_GET(self, request: Request) -> None:
|
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||||
|
|
||||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
|
@ -22,6 +22,7 @@ from twisted.web.server import Request
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.media.v1.media_storage import SpamMediaException
|
from synapse.rest.media.v1.media_storage import SpamMediaException
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
|
||||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||||
respond_with_json(request, 200, {}, send_cors=True)
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
|
|
||||||
async def _async_render_POST(self, request: Request) -> None:
|
async def _async_render_POST(self, request: SynapseRequest) -> None:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
# TODO: The checks here are a bit late. The content will have
|
# TODO: The checks here are a bit late. The content will have
|
||||||
# already been uploaded to a tmp file at this point
|
# already been uploaded to a tmp file at this point
|
||||||
|
|
|
@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
|
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
|
||||||
return (
|
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
|
||||||
InsecureInterceptableContextFactory()
|
return InsecureInterceptableContextFactory()
|
||||||
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
return RegularPolicyForHTTPS()
|
||||||
else RegularPolicyForHTTPS()
|
|
||||||
)
|
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_simple_http_client(self) -> SimpleHttpClient:
|
def get_simple_http_client(self) -> SimpleHttpClient:
|
||||||
|
|
|
@ -17,7 +17,7 @@ import mock
|
||||||
|
|
||||||
from synapse.app.generic_worker import GenericWorkerServer
|
from synapse.app.generic_worker import GenericWorkerServer
|
||||||
from synapse.replication.tcp.commands import FederationAckCommand
|
from synapse.replication.tcp.commands import FederationAckCommand
|
||||||
from synapse.replication.tcp.protocol import AbstractConnection
|
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||||
from synapse.replication.tcp.streams.federation import FederationStream
|
from synapse.replication.tcp.streams.federation import FederationStream
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
rch = self.hs.get_tcp_replication()
|
rch = self.hs.get_tcp_replication()
|
||||||
|
|
||||||
# wire up the ReplicationCommandHandler to a mock connection
|
# wire up the ReplicationCommandHandler to a mock connection, which needs
|
||||||
mock_connection = mock.Mock(spec=AbstractConnection)
|
# to implement IReplicationConnection. (Note that Mock doesn't understand
|
||||||
|
# interfaces, but casing an interface to a list gives the attributes.)
|
||||||
|
mock_connection = mock.Mock(spec=list(IReplicationConnection))
|
||||||
rch.new_connection(mock_connection)
|
rch.new_connection(mock_connection)
|
||||||
|
|
||||||
# tell it it received an RDATA row
|
# tell it it received an RDATA row
|
||||||
|
|
Loading…
Reference in New Issue