Periodically send pings to detect dead Redis connections (#9218)
This is done by creating a custom `RedisFactory` subclass that periodically pings all connections in its pool. We also ensure that the `replyTimeout` param is non-null, so that we timeout waiting for the reply to those pings (and thus triggering a reconnect).
This commit is contained in:
parent
5b857b77f7
commit
a1ff1e967f
|
@ -0,0 +1 @@
|
|||
Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.
|
|
@ -19,8 +19,9 @@ from typing import List, Optional, Type, Union
|
|||
|
||||
class RedisProtocol:
|
||||
def publish(self, channel: str, message: bytes): ...
|
||||
async def ping(self) -> None: ...
|
||||
|
||||
class SubscriberProtocol:
|
||||
class SubscriberProtocol(RedisProtocol):
|
||||
def __init__(self, *args, **kwargs): ...
|
||||
password: Optional[str]
|
||||
def subscribe(self, channels: Union[str, List[str]]): ...
|
||||
|
@ -39,14 +40,13 @@ def lazyConnection(
|
|||
convertNumbers: bool = ...,
|
||||
) -> RedisProtocol: ...
|
||||
|
||||
class SubscriberFactory:
|
||||
def buildProtocol(self, addr): ...
|
||||
|
||||
class ConnectionHandler: ...
|
||||
|
||||
class RedisFactory:
|
||||
continueTrying: bool
|
||||
handler: RedisProtocol
|
||||
pool: List[RedisProtocol]
|
||||
replyTimeout: Optional[int]
|
||||
def __init__(
|
||||
self,
|
||||
uuid: str,
|
||||
|
@ -59,3 +59,7 @@ class RedisFactory:
|
|||
replyTimeout: Optional[int] = None,
|
||||
convertNumbers: Optional[int] = True,
|
||||
): ...
|
||||
def buildProtocol(self, addr) -> RedisProtocol: ...
|
||||
|
||||
class SubscriberFactory(RedisFactory):
|
||||
def __init__(self): ...
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Dict,
|
||||
|
@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
|
|||
TypingStream,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -88,7 +92,7 @@ class ReplicationCommandHandler:
|
|||
back out to connections.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._replication_data_handler = hs.get_replication_data_handler()
|
||||
self._presence_handler = hs.get_presence_handler()
|
||||
self._store = hs.get_datastore()
|
||||
|
@ -300,7 +304,7 @@ class ReplicationCommandHandler:
|
|||
|
||||
# First create the connection for sending commands.
|
||||
outbound_redis_connection = lazyConnection(
|
||||
reactor=hs.get_reactor(),
|
||||
hs=hs,
|
||||
host=hs.config.redis_host,
|
||||
port=hs.config.redis_port,
|
||||
password=hs.config.redis.redis_password,
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import logging
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Type, cast
|
||||
|
||||
import txredisapi
|
||||
|
||||
|
@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
|
|||
from synapse.metrics.background_process_metrics import (
|
||||
BackgroundProcessLoggingContext,
|
||||
run_as_background_process,
|
||||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.replication.tcp.commands import (
|
||||
Command,
|
||||
|
@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
immediately after initialisation.
|
||||
|
||||
Attributes:
|
||||
handler: The command handler to handle incoming commands.
|
||||
stream_name: The *redis* stream name to subscribe to and publish from
|
||||
(not anything to do with Synapse replication streams).
|
||||
outbound_redis_connection: The connection to redis to use to send
|
||||
synapse_handler: The command handler to handle incoming commands.
|
||||
synapse_stream_name: The *redis* stream name to subscribe to and publish
|
||||
from (not anything to do with Synapse replication streams).
|
||||
synapse_outbound_redis_connection: The connection to redis to use to send
|
||||
commands.
|
||||
"""
|
||||
|
||||
handler = None # type: ReplicationCommandHandler
|
||||
stream_name = None # type: str
|
||||
outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
||||
synapse_handler = None # type: ReplicationCommandHandler
|
||||
synapse_stream_name = None # type: str
|
||||
synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
# it's important to make sure that we only send the REPLICATE command once we
|
||||
# have successfully subscribed to the stream - otherwise we might miss the
|
||||
# POSITION response sent back by the other end.
|
||||
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
|
||||
await make_deferred_yieldable(self.subscribe(self.stream_name))
|
||||
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
|
||||
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
|
||||
logger.info(
|
||||
"Successfully subscribed to redis stream, sending REPLICATE command"
|
||||
)
|
||||
self.handler.new_connection(self)
|
||||
self.synapse_handler.new_connection(self)
|
||||
await self._async_send_command(ReplicateCommand())
|
||||
logger.info("REPLICATE successfully sent")
|
||||
|
||||
# We send out our positions when there is a new connection in case the
|
||||
# other side missed updates. We do this for Redis connections as the
|
||||
# otherside won't know we've connected and so won't issue a REPLICATE.
|
||||
self.handler.send_positions_to_connection(self)
|
||||
self.synapse_handler.send_positions_to_connection(self)
|
||||
|
||||
def messageReceived(self, pattern: str, channel: str, message: str):
|
||||
"""Received a message from redis.
|
||||
|
@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
cmd: received command
|
||||
"""
|
||||
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
|
||||
if not cmd_func:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
return
|
||||
|
@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
def connectionLost(self, reason):
|
||||
logger.info("Lost connection to redis")
|
||||
super().connectionLost(reason)
|
||||
self.handler.lost_connection(self)
|
||||
self.synapse_handler.lost_connection(self)
|
||||
|
||||
# mark the logging context as finished
|
||||
self._logging_context.__exit__(None, None, None)
|
||||
|
@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
|
||||
|
||||
await make_deferred_yieldable(
|
||||
self.outbound_redis_connection.publish(self.stream_name, encoded_string)
|
||||
self.synapse_outbound_redis_connection.publish(
|
||||
self.synapse_stream_name, encoded_string
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
||||
class SynapseRedisFactory(txredisapi.RedisFactory):
|
||||
"""A subclass of RedisFactory that periodically sends pings to ensure that
|
||||
we detect dead connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
uuid: str,
|
||||
dbid: Optional[int],
|
||||
poolsize: int,
|
||||
isLazy: bool = False,
|
||||
handler: Type = txredisapi.ConnectionHandler,
|
||||
charset: str = "utf-8",
|
||||
password: Optional[str] = None,
|
||||
replyTimeout: int = 30,
|
||||
convertNumbers: Optional[int] = True,
|
||||
):
|
||||
super().__init__(
|
||||
uuid=uuid,
|
||||
dbid=dbid,
|
||||
poolsize=poolsize,
|
||||
isLazy=isLazy,
|
||||
handler=handler,
|
||||
charset=charset,
|
||||
password=password,
|
||||
replyTimeout=replyTimeout,
|
||||
convertNumbers=convertNumbers,
|
||||
)
|
||||
|
||||
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
|
||||
|
||||
@wrap_as_background_process("redis_ping")
|
||||
async def _send_ping(self):
|
||||
for connection in self.pool:
|
||||
try:
|
||||
await make_deferred_yieldable(connection.ping())
|
||||
except Exception:
|
||||
logger.warning("Failed to send ping to a redis connection")
|
||||
|
||||
|
||||
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
||||
"""This is a reconnecting factory that connects to redis and immediately
|
||||
subscribes to a stream.
|
||||
|
||||
|
@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
|||
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
hs,
|
||||
uuid="subscriber",
|
||||
dbid=None,
|
||||
poolsize=1,
|
||||
replyTimeout=30,
|
||||
password=hs.config.redis.redis_password,
|
||||
)
|
||||
|
||||
# This sets the password on the RedisFactory base class (as
|
||||
# SubscriberFactory constructor doesn't pass it through).
|
||||
self.password = hs.config.redis.redis_password
|
||||
self.synapse_handler = hs.get_tcp_replication()
|
||||
self.synapse_stream_name = hs.hostname
|
||||
|
||||
self.handler = hs.get_tcp_replication()
|
||||
self.stream_name = hs.hostname
|
||||
|
||||
self.outbound_redis_connection = outbound_redis_connection
|
||||
self.synapse_outbound_redis_connection = outbound_redis_connection
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = super().buildProtocol(addr) # type: RedisSubscriber
|
||||
p = super().buildProtocol(addr)
|
||||
p = cast(RedisSubscriber, p)
|
||||
|
||||
# We do this here rather than add to the constructor of `RedisSubcriber`
|
||||
# as to do so would involve overriding `buildProtocol` entirely, however
|
||||
# the base method does some other things than just instantiating the
|
||||
# protocol.
|
||||
p.handler = self.handler
|
||||
p.outbound_redis_connection = self.outbound_redis_connection
|
||||
p.stream_name = self.stream_name
|
||||
p.password = self.password
|
||||
p.synapse_handler = self.synapse_handler
|
||||
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
|
||||
p.synapse_stream_name = self.synapse_stream_name
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def lazyConnection(
|
||||
reactor,
|
||||
hs: "HomeServer",
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
dbid: Optional[int] = None,
|
||||
reconnect: bool = True,
|
||||
charset: str = "utf-8",
|
||||
password: Optional[str] = None,
|
||||
connectTimeout: Optional[int] = None,
|
||||
replyTimeout: Optional[int] = None,
|
||||
convertNumbers: bool = True,
|
||||
replyTimeout: int = 30,
|
||||
) -> txredisapi.RedisProtocol:
|
||||
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a
|
||||
reactor.
|
||||
"""Creates a connection to Redis that is lazily set up and reconnects if the
|
||||
connections is lost.
|
||||
"""
|
||||
|
||||
isLazy = True
|
||||
poolsize = 1
|
||||
|
||||
uuid = "%s:%d" % (host, port)
|
||||
factory = txredisapi.RedisFactory(
|
||||
uuid,
|
||||
dbid,
|
||||
poolsize,
|
||||
isLazy,
|
||||
txredisapi.ConnectionHandler,
|
||||
charset,
|
||||
password,
|
||||
replyTimeout,
|
||||
convertNumbers,
|
||||
factory = SynapseRedisFactory(
|
||||
hs,
|
||||
uuid=uuid,
|
||||
dbid=dbid,
|
||||
poolsize=1,
|
||||
isLazy=True,
|
||||
handler=txredisapi.ConnectionHandler,
|
||||
password=password,
|
||||
replyTimeout=replyTimeout,
|
||||
)
|
||||
factory.continueTrying = reconnect
|
||||
for x in range(poolsize):
|
||||
reactor.connectTCP(host, port, factory, connectTimeout)
|
||||
|
||||
reactor = hs.get_reactor()
|
||||
reactor.connectTCP(host, port, factory, 30)
|
||||
|
||||
return factory.handler
|
||||
|
|
Loading…
Reference in New Issue