Add missing type hints to `synapse.appservice` (#11360)
This commit is contained in:
parent
70ca05373b
commit
2519beaad2
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to `synapse.appservice`.
|
3
mypy.ini
3
mypy.ini
|
@ -143,6 +143,9 @@ disallow_untyped_defs = True
|
||||||
[mypy-synapse.app.*]
|
[mypy-synapse.app.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.appservice.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.config._base]
|
[mypy-synapse.config._base]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -11,10 +11,14 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Iterable, List, Match, Optional
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from netaddr import IPSet
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
@ -33,6 +37,13 @@ class ApplicationServiceState(Enum):
|
||||||
UP = "up"
|
UP = "up"
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class Namespace:
|
||||||
|
exclusive: bool
|
||||||
|
group_id: Optional[str]
|
||||||
|
regex: Pattern[str]
|
||||||
|
|
||||||
|
|
||||||
class ApplicationService:
|
class ApplicationService:
|
||||||
"""Defines an application service. This definition is mostly what is
|
"""Defines an application service. This definition is mostly what is
|
||||||
provided to the /register AS API.
|
provided to the /register AS API.
|
||||||
|
@ -50,17 +61,17 @@ class ApplicationService:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
token,
|
token: str,
|
||||||
hostname,
|
hostname: str,
|
||||||
id,
|
id: str,
|
||||||
sender,
|
sender: str,
|
||||||
url=None,
|
url: Optional[str] = None,
|
||||||
namespaces=None,
|
namespaces: Optional[JsonDict] = None,
|
||||||
hs_token=None,
|
hs_token: Optional[str] = None,
|
||||||
protocols=None,
|
protocols: Optional[Iterable[str]] = None,
|
||||||
rate_limited=True,
|
rate_limited: bool = True,
|
||||||
ip_range_whitelist=None,
|
ip_range_whitelist: Optional[IPSet] = None,
|
||||||
supports_ephemeral=False,
|
supports_ephemeral: bool = False,
|
||||||
):
|
):
|
||||||
self.token = token
|
self.token = token
|
||||||
self.url = (
|
self.url = (
|
||||||
|
@ -85,27 +96,33 @@ class ApplicationService:
|
||||||
|
|
||||||
self.rate_limited = rate_limited
|
self.rate_limited = rate_limited
|
||||||
|
|
||||||
def _check_namespaces(self, namespaces):
|
def _check_namespaces(
|
||||||
|
self, namespaces: Optional[JsonDict]
|
||||||
|
) -> Dict[str, List[Namespace]]:
|
||||||
# Sanity check that it is of the form:
|
# Sanity check that it is of the form:
|
||||||
# {
|
# {
|
||||||
# users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
|
# users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
|
||||||
# aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
|
# aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
|
||||||
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
|
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
|
||||||
# }
|
# }
|
||||||
if not namespaces:
|
if namespaces is None:
|
||||||
namespaces = {}
|
namespaces = {}
|
||||||
|
|
||||||
|
result: Dict[str, List[Namespace]] = {}
|
||||||
|
|
||||||
for ns in ApplicationService.NS_LIST:
|
for ns in ApplicationService.NS_LIST:
|
||||||
|
result[ns] = []
|
||||||
|
|
||||||
if ns not in namespaces:
|
if ns not in namespaces:
|
||||||
namespaces[ns] = []
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if type(namespaces[ns]) != list:
|
if not isinstance(namespaces[ns], list):
|
||||||
raise ValueError("Bad namespace value for '%s'" % ns)
|
raise ValueError("Bad namespace value for '%s'" % ns)
|
||||||
for regex_obj in namespaces[ns]:
|
for regex_obj in namespaces[ns]:
|
||||||
if not isinstance(regex_obj, dict):
|
if not isinstance(regex_obj, dict):
|
||||||
raise ValueError("Expected dict regex for ns '%s'" % ns)
|
raise ValueError("Expected dict regex for ns '%s'" % ns)
|
||||||
if not isinstance(regex_obj.get("exclusive"), bool):
|
exclusive = regex_obj.get("exclusive")
|
||||||
|
if not isinstance(exclusive, bool):
|
||||||
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
|
raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
|
||||||
group_id = regex_obj.get("group_id")
|
group_id = regex_obj.get("group_id")
|
||||||
if group_id:
|
if group_id:
|
||||||
|
@ -126,22 +143,26 @@ class ApplicationService:
|
||||||
)
|
)
|
||||||
|
|
||||||
regex = regex_obj.get("regex")
|
regex = regex_obj.get("regex")
|
||||||
if isinstance(regex, str):
|
if not isinstance(regex, str):
|
||||||
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
|
|
||||||
else:
|
|
||||||
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
|
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
|
||||||
return namespaces
|
|
||||||
|
|
||||||
def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
|
# Pre-compile regex.
|
||||||
for regex_obj in self.namespaces[namespace_key]:
|
result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
|
||||||
if regex_obj["regex"].match(test_string):
|
|
||||||
return regex_obj
|
return result
|
||||||
|
|
||||||
|
def _matches_regex(
|
||||||
|
self, namespace_key: str, test_string: str
|
||||||
|
) -> Optional[Namespace]:
|
||||||
|
for namespace in self.namespaces[namespace_key]:
|
||||||
|
if namespace.regex.match(test_string):
|
||||||
|
return namespace
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
|
def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
|
||||||
regex_obj = self._matches_regex(test_string, ns_key)
|
namespace = self._matches_regex(namespace_key, test_string)
|
||||||
if regex_obj:
|
if namespace:
|
||||||
return regex_obj["exclusive"]
|
return namespace.exclusive
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _matches_user(
|
async def _matches_user(
|
||||||
|
@ -260,15 +281,15 @@ class ApplicationService:
|
||||||
|
|
||||||
def is_interested_in_user(self, user_id: str) -> bool:
|
def is_interested_in_user(self, user_id: str) -> bool:
|
||||||
return (
|
return (
|
||||||
bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
|
bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
|
||||||
or user_id == self.sender
|
or user_id == self.sender
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_interested_in_alias(self, alias: str) -> bool:
|
def is_interested_in_alias(self, alias: str) -> bool:
|
||||||
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
|
return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
|
||||||
|
|
||||||
def is_interested_in_room(self, room_id: str) -> bool:
|
def is_interested_in_room(self, room_id: str) -> bool:
|
||||||
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
|
return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
|
||||||
|
|
||||||
def is_exclusive_user(self, user_id: str) -> bool:
|
def is_exclusive_user(self, user_id: str) -> bool:
|
||||||
return (
|
return (
|
||||||
|
@ -285,14 +306,14 @@ class ApplicationService:
|
||||||
def is_exclusive_room(self, room_id: str) -> bool:
|
def is_exclusive_room(self, room_id: str) -> bool:
|
||||||
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
|
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
|
||||||
|
|
||||||
def get_exclusive_user_regexes(self):
|
def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
|
||||||
"""Get the list of regexes used to determine if a user is exclusively
|
"""Get the list of regexes used to determine if a user is exclusively
|
||||||
registered by the AS
|
registered by the AS
|
||||||
"""
|
"""
|
||||||
return [
|
return [
|
||||||
regex_obj["regex"]
|
namespace.regex
|
||||||
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
|
for namespace in self.namespaces[ApplicationService.NS_USERS]
|
||||||
if regex_obj["exclusive"]
|
if namespace.exclusive
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_groups_for_user(self, user_id: str) -> Iterable[str]:
|
def get_groups_for_user(self, user_id: str) -> Iterable[str]:
|
||||||
|
@ -305,15 +326,15 @@ class ApplicationService:
|
||||||
An iterable that yields group_id strings.
|
An iterable that yields group_id strings.
|
||||||
"""
|
"""
|
||||||
return (
|
return (
|
||||||
regex_obj["group_id"]
|
namespace.group_id
|
||||||
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
|
for namespace in self.namespaces[ApplicationService.NS_USERS]
|
||||||
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
|
if namespace.group_id and namespace.regex.match(user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_rate_limited(self) -> bool:
|
def is_rate_limited(self) -> bool:
|
||||||
return self.rate_limited
|
return self.rate_limited
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
# copy dictionary and redact token fields so they don't get logged
|
# copy dictionary and redact token fields so they don't get logged
|
||||||
dict_copy = self.__dict__.copy()
|
dict_copy = self.__dict__.copy()
|
||||||
dict_copy["token"] = "<redacted>"
|
dict_copy["token"] = "<redacted>"
|
||||||
|
|
|
@ -12,8 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib.parse
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ HOUR_IN_MS = 60 * 60 * 1000
|
||||||
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
|
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
|
||||||
|
|
||||||
|
|
||||||
def _is_valid_3pe_metadata(info):
|
def _is_valid_3pe_metadata(info: JsonDict) -> bool:
|
||||||
if "instances" not in info:
|
if "instances" not in info:
|
||||||
return False
|
return False
|
||||||
if not isinstance(info["instances"], list):
|
if not isinstance(info["instances"], list):
|
||||||
|
@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _is_valid_3pe_result(r, field):
|
def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
|
||||||
if not isinstance(r, dict):
|
if not isinstance(r, dict):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -93,9 +93,13 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
|
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
|
||||||
)
|
)
|
||||||
|
|
||||||
async def query_user(self, service, user_id):
|
async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
|
||||||
if service.url is None:
|
if service.url is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# This is required by the configuration.
|
||||||
|
assert service.hs_token is not None
|
||||||
|
|
||||||
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
|
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
|
||||||
try:
|
try:
|
||||||
response = await self.get_json(uri, {"access_token": service.hs_token})
|
response = await self.get_json(uri, {"access_token": service.hs_token})
|
||||||
|
@ -109,9 +113,13 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
logger.warning("query_user to %s threw exception %s", uri, ex)
|
logger.warning("query_user to %s threw exception %s", uri, ex)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def query_alias(self, service, alias):
|
async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
|
||||||
if service.url is None:
|
if service.url is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# This is required by the configuration.
|
||||||
|
assert service.hs_token is not None
|
||||||
|
|
||||||
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
|
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
|
||||||
try:
|
try:
|
||||||
response = await self.get_json(uri, {"access_token": service.hs_token})
|
response = await self.get_json(uri, {"access_token": service.hs_token})
|
||||||
|
@ -125,7 +133,13 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
logger.warning("query_alias to %s threw exception %s", uri, ex)
|
logger.warning("query_alias to %s threw exception %s", uri, ex)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def query_3pe(self, service, kind, protocol, fields):
|
async def query_3pe(
|
||||||
|
self,
|
||||||
|
service: "ApplicationService",
|
||||||
|
kind: str,
|
||||||
|
protocol: str,
|
||||||
|
fields: Dict[bytes, List[bytes]],
|
||||||
|
) -> List[JsonDict]:
|
||||||
if kind == ThirdPartyEntityKind.USER:
|
if kind == ThirdPartyEntityKind.USER:
|
||||||
required_field = "userid"
|
required_field = "userid"
|
||||||
elif kind == ThirdPartyEntityKind.LOCATION:
|
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||||
|
@ -205,11 +219,14 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
ephemeral: List[JsonDict],
|
ephemeral: List[JsonDict],
|
||||||
txn_id: Optional[int] = None,
|
txn_id: Optional[int] = None,
|
||||||
):
|
) -> bool:
|
||||||
if service.url is None:
|
if service.url is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
events = self._serialize(service, events)
|
# This is required by the configuration.
|
||||||
|
assert service.hs_token is not None
|
||||||
|
|
||||||
|
serialized_events = self._serialize(service, events)
|
||||||
|
|
||||||
if txn_id is None:
|
if txn_id is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -221,9 +238,12 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
|
|
||||||
# Never send ephemeral events to appservices that do not support it
|
# Never send ephemeral events to appservices that do not support it
|
||||||
if service.supports_ephemeral:
|
if service.supports_ephemeral:
|
||||||
body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
|
body = {
|
||||||
|
"events": serialized_events,
|
||||||
|
"de.sorunome.msc2409.ephemeral": ephemeral,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
body = {"events": events}
|
body = {"events": serialized_events}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.put_json(
|
await self.put_json(
|
||||||
|
@ -238,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
[event.get("event_id") for event in events],
|
[event.get("event_id") for event in events],
|
||||||
)
|
)
|
||||||
sent_transactions_counter.labels(service.id).inc()
|
sent_transactions_counter.labels(service.id).inc()
|
||||||
sent_events_counter.labels(service.id).inc(len(events))
|
sent_events_counter.labels(service.id).inc(len(serialized_events))
|
||||||
return True
|
return True
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -260,7 +280,9 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
failed_transactions_counter.labels(service.id).inc()
|
failed_transactions_counter.labels(service.id).inc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _serialize(self, service, events):
|
def _serialize(
|
||||||
|
self, service: "ApplicationService", events: Iterable[EventBase]
|
||||||
|
) -> List[JsonDict]:
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
return [
|
return [
|
||||||
serialize_event(
|
serialize_event(
|
||||||
|
|
|
@ -48,13 +48,19 @@ This is all tied together by the AppServiceScheduler which DIs the required
|
||||||
components.
|
components.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService, ApplicationServiceState
|
from synapse.appservice import ApplicationService, ApplicationServiceState
|
||||||
|
from synapse.appservice.api import ApplicationServiceApi
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.logging.context import run_in_background
|
from synapse.logging.context import run_in_background
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -72,7 +78,7 @@ class ApplicationServiceScheduler:
|
||||||
case is a simple array.
|
case is a simple array.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.as_api = hs.get_application_service_api()
|
self.as_api = hs.get_application_service_api()
|
||||||
|
@ -80,7 +86,7 @@ class ApplicationServiceScheduler:
|
||||||
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
|
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
|
||||||
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
|
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self) -> None:
|
||||||
logger.info("Starting appservice scheduler")
|
logger.info("Starting appservice scheduler")
|
||||||
|
|
||||||
# check for any DOWN ASes and start recoverers for them.
|
# check for any DOWN ASes and start recoverers for them.
|
||||||
|
@ -91,12 +97,14 @@ class ApplicationServiceScheduler:
|
||||||
for service in services:
|
for service in services:
|
||||||
self.txn_ctrl.start_recoverer(service)
|
self.txn_ctrl.start_recoverer(service)
|
||||||
|
|
||||||
def submit_event_for_as(self, service: ApplicationService, event: EventBase):
|
def submit_event_for_as(
|
||||||
|
self, service: ApplicationService, event: EventBase
|
||||||
|
) -> None:
|
||||||
self.queuer.enqueue_event(service, event)
|
self.queuer.enqueue_event(service, event)
|
||||||
|
|
||||||
def submit_ephemeral_events_for_as(
|
def submit_ephemeral_events_for_as(
|
||||||
self, service: ApplicationService, events: List[JsonDict]
|
self, service: ApplicationService, events: List[JsonDict]
|
||||||
):
|
) -> None:
|
||||||
self.queuer.enqueue_ephemeral(service, events)
|
self.queuer.enqueue_ephemeral(service, events)
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,16 +116,18 @@ class _ServiceQueuer:
|
||||||
appservice at a given time.
|
appservice at a given time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, txn_ctrl, clock):
|
def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
|
||||||
self.queued_events = {} # dict of {service_id: [events]}
|
# dict of {service_id: [events]}
|
||||||
self.queued_ephemeral = {} # dict of {service_id: [events]}
|
self.queued_events: Dict[str, List[EventBase]] = {}
|
||||||
|
# dict of {service_id: [events]}
|
||||||
|
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
|
||||||
|
|
||||||
# the appservices which currently have a transaction in flight
|
# the appservices which currently have a transaction in flight
|
||||||
self.requests_in_flight = set()
|
self.requests_in_flight: Set[str] = set()
|
||||||
self.txn_ctrl = txn_ctrl
|
self.txn_ctrl = txn_ctrl
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
|
|
||||||
def _start_background_request(self, service):
|
def _start_background_request(self, service: ApplicationService) -> None:
|
||||||
# start a sender for this appservice if we don't already have one
|
# start a sender for this appservice if we don't already have one
|
||||||
if service.id in self.requests_in_flight:
|
if service.id in self.requests_in_flight:
|
||||||
return
|
return
|
||||||
|
@ -126,15 +136,17 @@ class _ServiceQueuer:
|
||||||
"as-sender-%s" % (service.id,), self._send_request, service
|
"as-sender-%s" % (service.id,), self._send_request, service
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue_event(self, service: ApplicationService, event: EventBase):
|
def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
|
||||||
self.queued_events.setdefault(service.id, []).append(event)
|
self.queued_events.setdefault(service.id, []).append(event)
|
||||||
self._start_background_request(service)
|
self._start_background_request(service)
|
||||||
|
|
||||||
def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
|
def enqueue_ephemeral(
|
||||||
|
self, service: ApplicationService, events: List[JsonDict]
|
||||||
|
) -> None:
|
||||||
self.queued_ephemeral.setdefault(service.id, []).extend(events)
|
self.queued_ephemeral.setdefault(service.id, []).extend(events)
|
||||||
self._start_background_request(service)
|
self._start_background_request(service)
|
||||||
|
|
||||||
async def _send_request(self, service: ApplicationService):
|
async def _send_request(self, service: ApplicationService) -> None:
|
||||||
# sanity-check: we shouldn't get here if this service already has a sender
|
# sanity-check: we shouldn't get here if this service already has a sender
|
||||||
# running.
|
# running.
|
||||||
assert service.id not in self.requests_in_flight
|
assert service.id not in self.requests_in_flight
|
||||||
|
@ -168,20 +180,15 @@ class _TransactionController:
|
||||||
if a transaction fails.
|
if a transaction fails.
|
||||||
|
|
||||||
(Note we have only have one of these in the homeserver.)
|
(Note we have only have one of these in the homeserver.)
|
||||||
|
|
||||||
Args:
|
|
||||||
clock (synapse.util.Clock):
|
|
||||||
store (synapse.storage.DataStore):
|
|
||||||
as_api (synapse.appservice.api.ApplicationServiceApi):
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, clock, store, as_api):
|
def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi):
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
self.store = store
|
self.store = store
|
||||||
self.as_api = as_api
|
self.as_api = as_api
|
||||||
|
|
||||||
# map from service id to recoverer instance
|
# map from service id to recoverer instance
|
||||||
self.recoverers = {}
|
self.recoverers: Dict[str, "_Recoverer"] = {}
|
||||||
|
|
||||||
# for UTs
|
# for UTs
|
||||||
self.RECOVERER_CLASS = _Recoverer
|
self.RECOVERER_CLASS = _Recoverer
|
||||||
|
@ -191,7 +198,7 @@ class _TransactionController:
|
||||||
service: ApplicationService,
|
service: ApplicationService,
|
||||||
events: List[EventBase],
|
events: List[EventBase],
|
||||||
ephemeral: Optional[List[JsonDict]] = None,
|
ephemeral: Optional[List[JsonDict]] = None,
|
||||||
):
|
) -> None:
|
||||||
try:
|
try:
|
||||||
txn = await self.store.create_appservice_txn(
|
txn = await self.store.create_appservice_txn(
|
||||||
service=service, events=events, ephemeral=ephemeral or []
|
service=service, events=events, ephemeral=ephemeral or []
|
||||||
|
@ -207,7 +214,7 @@ class _TransactionController:
|
||||||
logger.exception("Error creating appservice transaction")
|
logger.exception("Error creating appservice transaction")
|
||||||
run_in_background(self._on_txn_fail, service)
|
run_in_background(self._on_txn_fail, service)
|
||||||
|
|
||||||
async def on_recovered(self, recoverer):
|
async def on_recovered(self, recoverer: "_Recoverer") -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Successfully recovered application service AS ID %s", recoverer.service.id
|
"Successfully recovered application service AS ID %s", recoverer.service.id
|
||||||
)
|
)
|
||||||
|
@ -217,18 +224,18 @@ class _TransactionController:
|
||||||
recoverer.service, ApplicationServiceState.UP
|
recoverer.service, ApplicationServiceState.UP
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _on_txn_fail(self, service):
|
async def _on_txn_fail(self, service: ApplicationService) -> None:
|
||||||
try:
|
try:
|
||||||
await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
|
||||||
self.start_recoverer(service)
|
self.start_recoverer(service)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error starting AS recoverer")
|
logger.exception("Error starting AS recoverer")
|
||||||
|
|
||||||
def start_recoverer(self, service):
|
def start_recoverer(self, service: ApplicationService) -> None:
|
||||||
"""Start a Recoverer for the given service
|
"""Start a Recoverer for the given service
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service (synapse.appservice.ApplicationService):
|
service:
|
||||||
"""
|
"""
|
||||||
logger.info("Starting recoverer for AS ID %s", service.id)
|
logger.info("Starting recoverer for AS ID %s", service.id)
|
||||||
assert service.id not in self.recoverers
|
assert service.id not in self.recoverers
|
||||||
|
@ -257,7 +264,14 @@ class _Recoverer:
|
||||||
callback (callable[_Recoverer]): called once the service recovers.
|
callback (callable[_Recoverer]): called once the service recovers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, clock, store, as_api, service, callback):
|
def __init__(
|
||||||
|
self,
|
||||||
|
clock: Clock,
|
||||||
|
store: DataStore,
|
||||||
|
as_api: ApplicationServiceApi,
|
||||||
|
service: ApplicationService,
|
||||||
|
callback: Callable[["_Recoverer"], Awaitable[None]],
|
||||||
|
):
|
||||||
self.clock = clock
|
self.clock = clock
|
||||||
self.store = store
|
self.store = store
|
||||||
self.as_api = as_api
|
self.as_api = as_api
|
||||||
|
@ -265,8 +279,8 @@ class _Recoverer:
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
self.backoff_counter = 1
|
self.backoff_counter = 1
|
||||||
|
|
||||||
def recover(self):
|
def recover(self) -> None:
|
||||||
def _retry():
|
def _retry() -> None:
|
||||||
run_as_background_process(
|
run_as_background_process(
|
||||||
"as-recoverer-%s" % (self.service.id,), self.retry
|
"as-recoverer-%s" % (self.service.id,), self.retry
|
||||||
)
|
)
|
||||||
|
@ -275,13 +289,13 @@ class _Recoverer:
|
||||||
logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
|
logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
|
||||||
self.clock.call_later(delay, _retry)
|
self.clock.call_later(delay, _retry)
|
||||||
|
|
||||||
def _backoff(self):
|
def _backoff(self) -> None:
|
||||||
# cap the backoff to be around 8.5min => (2^9) = 512 secs
|
# cap the backoff to be around 8.5min => (2^9) = 512 secs
|
||||||
if self.backoff_counter < 9:
|
if self.backoff_counter < 9:
|
||||||
self.backoff_counter += 1
|
self.backoff_counter += 1
|
||||||
self.recover()
|
self.recover()
|
||||||
|
|
||||||
async def retry(self):
|
async def retry(self) -> None:
|
||||||
logger.info("Starting retries on %s", self.service.id)
|
logger.info("Starting retries on %s", self.service.id)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|
|
@ -147,8 +147,7 @@ def _load_appservice(
|
||||||
# protocols check
|
# protocols check
|
||||||
protocols = as_info.get("protocols")
|
protocols = as_info.get("protocols")
|
||||||
if protocols:
|
if protocols:
|
||||||
# Because strings are lists in python
|
if not isinstance(protocols, list):
|
||||||
if isinstance(protocols, str) or not isinstance(protocols, list):
|
|
||||||
raise KeyError("Optional 'protocols' must be a list if present.")
|
raise KeyError("Optional 'protocols' must be a list if present.")
|
||||||
for p in protocols:
|
for p in protocols:
|
||||||
if not isinstance(p, str):
|
if not isinstance(p, str):
|
||||||
|
|
|
@ -16,13 +16,13 @@ from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService, Namespace
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
def _regex(regex, exclusive=True):
|
def _regex(regex: str, exclusive: bool = True) -> Namespace:
|
||||||
return {"regex": re.compile(regex), "exclusive": exclusive}
|
return Namespace(exclusive, None, re.compile(regex))
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceTestCase(unittest.TestCase):
|
class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
|
@ -33,11 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
url="some_url",
|
url="some_url",
|
||||||
token="some_token",
|
token="some_token",
|
||||||
hostname="matrix.org", # only used by get_groups_for_user
|
hostname="matrix.org", # only used by get_groups_for_user
|
||||||
namespaces={
|
|
||||||
ApplicationService.NS_USERS: [],
|
|
||||||
ApplicationService.NS_ROOMS: [],
|
|
||||||
ApplicationService.NS_ALIASES: [],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
self.event = Mock(
|
self.event = Mock(
|
||||||
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
||||||
|
|
Loading…
Reference in New Issue