Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. (#11617)

Co-authored-by: Erik Johnston <erik@matrix.org>
This commit is contained in:
reivilibre 2022-02-24 17:55:45 +00:00 committed by GitHub
parent 41cf4c2cf6
commit 2cc5ea933d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 528 additions and 38 deletions

View File

@ -0,0 +1 @@
Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services.

View File

@ -31,6 +31,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Type for the `device_one_time_key_counts` field in an appservice transaction
# user ID -> {device ID -> {algorithm -> count}}
TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]]
# Type for the `device_unused_fallback_keys` field in an appservice transaction
# user ID -> {device ID -> [algorithm]}
TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]]
class ApplicationServiceState(Enum): class ApplicationServiceState(Enum):
DOWN = "down" DOWN = "down"
@ -72,6 +80,7 @@ class ApplicationService:
rate_limited: bool = True, rate_limited: bool = True,
ip_range_whitelist: Optional[IPSet] = None, ip_range_whitelist: Optional[IPSet] = None,
supports_ephemeral: bool = False, supports_ephemeral: bool = False,
msc3202_transaction_extensions: bool = False,
): ):
self.token = token self.token = token
self.url = ( self.url = (
@ -84,6 +93,7 @@ class ApplicationService:
self.id = id self.id = id
self.ip_range_whitelist = ip_range_whitelist self.ip_range_whitelist = ip_range_whitelist
self.supports_ephemeral = supports_ephemeral self.supports_ephemeral = supports_ephemeral
self.msc3202_transaction_extensions = msc3202_transaction_extensions
if "|" in self.id: if "|" in self.id:
raise Exception("application service ID cannot contain '|' character") raise Exception("application service ID cannot contain '|' character")
@ -339,12 +349,16 @@ class AppServiceTransaction:
events: List[EventBase], events: List[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
): ):
self.service = service self.service = service
self.id = id self.id = id
self.events = events self.events = events
self.ephemeral = ephemeral self.ephemeral = ephemeral
self.to_device_messages = to_device_messages self.to_device_messages = to_device_messages
self.one_time_key_counts = one_time_key_counts
self.unused_fallback_keys = unused_fallback_keys
async def send(self, as_api: "ApplicationServiceApi") -> bool: async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface. """Sends this transaction using the provided AS API interface.
@ -359,6 +373,8 @@ class AppServiceTransaction:
events=self.events, events=self.events,
ephemeral=self.ephemeral, ephemeral=self.ephemeral,
to_device_messages=self.to_device_messages, to_device_messages=self.to_device_messages,
one_time_key_counts=self.one_time_key_counts,
unused_fallback_keys=self.unused_fallback_keys,
txn_id=self.id, txn_id=self.id,
) )

View File

@ -19,6 +19,11 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
@ -26,7 +31,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.appservice import ApplicationService
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -219,6 +223,8 @@ class ApplicationServiceApi(SimpleHttpClient):
events: List[EventBase], events: List[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
txn_id: Optional[int] = None, txn_id: Optional[int] = None,
) -> bool: ) -> bool:
""" """
@ -252,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
# Never send ephemeral events to appservices that do not support it # Never send ephemeral events to appservices that do not support it
body: Dict[str, List[JsonDict]] = {"events": serialized_events} body: JsonDict = {"events": serialized_events}
if service.supports_ephemeral: if service.supports_ephemeral:
body.update( body.update(
{ {
@ -262,6 +268,16 @@ class ApplicationServiceApi(SimpleHttpClient):
} }
) )
if service.msc3202_transaction_extensions:
if one_time_key_counts:
body[
"org.matrix.msc3202.device_one_time_key_counts"
] = one_time_key_counts
if unused_fallback_keys:
body[
"org.matrix.msc3202.device_unused_fallback_keys"
] = unused_fallback_keys
try: try:
await self.put_json( await self.put_json(
uri=uri, uri=uri,

View File

@ -54,12 +54,19 @@ from typing import (
Callable, Callable,
Collection, Collection,
Dict, Dict,
Iterable,
List, List,
Optional, Optional,
Set, Set,
Tuple,
) )
from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.api import ApplicationServiceApi
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
@ -96,7 +103,7 @@ class ApplicationServiceScheduler:
self.as_api = hs.get_application_service_api() self.as_api = hs.get_application_service_api()
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs)
async def start(self) -> None: async def start(self) -> None:
logger.info("Starting appservice scheduler") logger.info("Starting appservice scheduler")
@ -153,7 +160,9 @@ class _ServiceQueuer:
appservice at a given time. appservice at a given time.
""" """
def __init__(self, txn_ctrl: "_TransactionController", clock: Clock): def __init__(
self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer"
):
# dict of {service_id: [events]} # dict of {service_id: [events]}
self.queued_events: Dict[str, List[EventBase]] = {} self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]} # dict of {service_id: [events]}
@ -165,6 +174,10 @@ class _ServiceQueuer:
self.requests_in_flight: Set[str] = set() self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl self.txn_ctrl = txn_ctrl
self.clock = clock self.clock = clock
self._msc3202_transaction_extensions_enabled: bool = (
hs.config.experimental.msc3202_transaction_extensions
)
self._store = hs.get_datastores().main
def start_background_request(self, service: ApplicationService) -> None: def start_background_request(self, service: ApplicationService) -> None:
# start a sender for this appservice if we don't already have one # start a sender for this appservice if we don't already have one
@ -202,15 +215,84 @@ class _ServiceQueuer:
if not events and not ephemeral and not to_device_messages_to_send: if not events and not ephemeral and not to_device_messages_to_send:
return return
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None
if (
self._msc3202_transaction_extensions_enabled
and service.msc3202_transaction_extensions
):
# Compute the one-time key counts and fallback key usage states
# for the users which are mentioned in this transaction,
# as well as the appservice's sender.
(
one_time_key_counts,
unused_fallback_keys,
) = await self._compute_msc3202_otk_counts_and_fallback_keys(
service, events, ephemeral, to_device_messages_to_send
)
try: try:
await self.txn_ctrl.send( await self.txn_ctrl.send(
service, events, ephemeral, to_device_messages_to_send service,
events,
ephemeral,
to_device_messages_to_send,
one_time_key_counts,
unused_fallback_keys,
) )
except Exception: except Exception:
logger.exception("AS request failed") logger.exception("AS request failed")
finally: finally:
self.requests_in_flight.discard(service.id) self.requests_in_flight.discard(service.id)
async def _compute_msc3202_otk_counts_and_fallback_keys(
self,
service: ApplicationService,
events: Iterable[EventBase],
ephemerals: Iterable[JsonDict],
to_device_messages: Iterable[JsonDict],
) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]:
"""
Given a list of the events, ephemeral messages and to-device messages,
- first computes a list of application services users that may have
interesting updates to the one-time key counts or fallback key usage.
- then computes one-time key counts and fallback key usages for those users.
Given a list of application service users that are interesting,
compute one-time key counts and fallback key usages for the users.
"""
# Set of 'interesting' users who may have updates
users: Set[str] = set()
# The sender is always included
users.add(service.sender)
# All AS users that would receive the PDUs or EDUs sent to these rooms
# are classed as 'interesting'.
rooms_of_interesting_users: Set[str] = set()
# PDUs
rooms_of_interesting_users.update(event.room_id for event in events)
# EDUs
rooms_of_interesting_users.update(
ephemeral["room_id"] for ephemeral in ephemerals
)
# Look up the AS users in those rooms
for room_id in rooms_of_interesting_users:
users.update(
await self._store.get_app_service_users_in_room(room_id, service)
)
# Add recipients of to-device messages.
# device_message["user_id"] is the ID of the recipient.
users.update(device_message["user_id"] for device_message in to_device_messages)
# Compute and return the counts / fallback key usage states
otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users)
unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users)
return otk_counts, unused_fbks
class _TransactionController: class _TransactionController:
"""Transaction manager. """Transaction manager.
@ -238,6 +320,8 @@ class _TransactionController:
events: List[EventBase], events: List[EventBase],
ephemeral: Optional[List[JsonDict]] = None, ephemeral: Optional[List[JsonDict]] = None,
to_device_messages: Optional[List[JsonDict]] = None, to_device_messages: Optional[List[JsonDict]] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
) -> None: ) -> None:
""" """
Create a transaction with the given data and send to the provided Create a transaction with the given data and send to the provided
@ -248,6 +332,10 @@ class _TransactionController:
events: The persistent events to include in the transaction. events: The persistent events to include in the transaction.
ephemeral: The ephemeral events to include in the transaction. ephemeral: The ephemeral events to include in the transaction.
to_device_messages: The to-device messages to include in the transaction. to_device_messages: The to-device messages to include in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
""" """
try: try:
txn = await self.store.create_appservice_txn( txn = await self.store.create_appservice_txn(
@ -255,6 +343,8 @@ class _TransactionController:
events=events, events=events,
ephemeral=ephemeral or [], ephemeral=ephemeral or [],
to_device_messages=to_device_messages or [], to_device_messages=to_device_messages or [],
one_time_key_counts=one_time_key_counts or {},
unused_fallback_keys=unused_fallback_keys or {},
) )
service_is_up = await self._is_service_up(service) service_is_up = await self._is_service_up(service)
if service_is_up: if service_is_up:

View File

@ -166,6 +166,16 @@ def _load_appservice(
supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False)
# Opt-in flag for the MSC3202-specific transactional behaviour.
# When enabled, appservice transactions contain the following information:
# - device One-Time Key counts
# - device unused fallback key usage states
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
if not isinstance(msc3202_transaction_extensions, bool):
raise ValueError(
"The `org.matrix.msc3202` option should be true or false if specified."
)
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
hostname=hostname, hostname=hostname,
@ -174,8 +184,9 @@ def _load_appservice(
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],
sender=user_id, sender=user_id,
id=as_info["id"], id=as_info["id"],
supports_ephemeral=supports_ephemeral,
protocols=protocols, protocols=protocols,
rate_limited=rate_limited, rate_limited=rate_limited,
ip_range_whitelist=ip_range_whitelist, ip_range_whitelist=ip_range_whitelist,
supports_ephemeral=supports_ephemeral,
msc3202_transaction_extensions=msc3202_transaction_extensions,
) )

View File

@ -47,11 +47,6 @@ class ExperimentalConfig(Config):
# MSC3030 (Jump to date API endpoint) # MSC3030 (Jump to date API endpoint)
self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False)
# The portion of MSC3202 which is related to device masquerading.
self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False
)
# MSC2409 (this setting only relates to optionally sending to-device messages). # MSC2409 (this setting only relates to optionally sending to-device messages).
# Presence, typing and read receipt EDUs are already sent to application services that # Presence, typing and read receipt EDUs are already sent to application services that
# have opted in to receive them. If enabled, this adds to-device messages to that list. # have opted in to receive them. If enabled, this adds to-device messages to that list.
@ -59,6 +54,17 @@ class ExperimentalConfig(Config):
"msc2409_to_device_messages_enabled", False "msc2409_to_device_messages_enabled", False
) )
# The portion of MSC3202 which is related to device masquerading.
self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False
)
# Portion of MSC3202 related to transaction extensions:
# sending one-time key counts and fallback key usage to application services.
self.msc3202_transaction_extensions: bool = experimental.get(
"msc3202_transaction_extensions", False
)
# MSC3706 (server-side support for partial state in /send_join responses) # MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)

View File

@ -20,14 +20,18 @@ from synapse.appservice import (
ApplicationService, ApplicationService,
ApplicationServiceState, ApplicationServiceState,
AppServiceTransaction, AppServiceTransaction,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
) )
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -56,7 +60,7 @@ def _make_exclusive_regex(
return exclusive_user_pattern return exclusive_user_pattern
class ApplicationServiceWorkerStore(SQLBaseStore): class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@ -124,6 +128,18 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
return service return service
return None return None
@cached(iterable=True, cache_context=True)
async def get_app_service_users_in_room(
self,
room_id: str,
app_service: "ApplicationService",
cache_context: _CacheContext,
) -> List[str]:
users_in_room = await self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
return list(filter(app_service.is_interested_in_user, users_in_room))
class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceStore(ApplicationServiceWorkerStore):
# This is currently empty due to there not being any AS storage functions # This is currently empty due to there not being any AS storage functions
@ -199,6 +215,8 @@ class ApplicationServiceTransactionWorkerStore(
events: List[EventBase], events: List[EventBase],
ephemeral: List[JsonDict], ephemeral: List[JsonDict],
to_device_messages: List[JsonDict], to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
) -> AppServiceTransaction: ) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service """Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the with the given list of events. Ephemeral events are NOT persisted to the
@ -209,6 +227,10 @@ class ApplicationServiceTransactionWorkerStore(
events: A list of persistent events to put in the transaction. events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction. to_device_messages: A list of to-device messages to put in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
Returns: Returns:
A new transaction. A new transaction.
@ -244,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore(
events=events, events=events,
ephemeral=ephemeral, ephemeral=ephemeral,
to_device_messages=to_device_messages, to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
) )
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -335,12 +359,17 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts and unused fallback keys
# are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction( return AppServiceTransaction(
service=service, service=service,
id=entry["txn_id"], id=entry["txn_id"],
events=events, events=events,
ephemeral=[], ephemeral=[],
to_device_messages=[], to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
) )
def _get_last_txn(self, txn, service_id: Optional[str]) -> int: def _get_last_txn(self, txn, service_id: Optional[str]) -> int:

View File

@ -29,6 +29,10 @@ import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from synapse.api.constants import DeviceKeyAlgorithms from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import ( from synapse.storage.database import (
@ -439,6 +443,114 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"count_e2e_one_time_keys", _count_e2e_one_time_keys "count_e2e_one_time_keys", _count_e2e_one_time_keys
) )
async def count_bulk_e2e_one_time_keys_for_as(
self, user_ids: Collection[str]
) -> TransactionOneTimeKeyCounts:
"""
Counts, in bulk, the one-time keys for all the users specified.
Intended to be used by application services for populating OTK counts in
transactions.
Return structure is of the shape:
user_id -> device_id -> algorithm -> count
Empty algorithm -> count dicts are created if needed to represent a
lack of unused one-time keys.
"""
def _count_bulk_e2e_one_time_keys_txn(
txn: LoggingTransaction,
) -> TransactionOneTimeKeyCounts:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
sql = f"""
SELECT user_id, device_id, algorithm, COUNT(key_id)
FROM devices
LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
WHERE {user_in_where_clause}
GROUP BY user_id, device_id, algorithm
"""
txn.execute(sql, user_parameters)
result: TransactionOneTimeKeyCounts = {}
for user_id, device_id, algorithm, count in txn:
# We deliberately construct empty dictionaries for
# users and devices without any unused one-time keys.
# We *could* omit these empty dicts if there have been no
# changes since the last transaction, but we currently don't
# do any change tracking!
device_count_by_algo = result.setdefault(user_id, {}).setdefault(
device_id, {}
)
if algorithm is not None:
# algorithm will be None if this device has no keys.
device_count_by_algo[algorithm] = count
return result
return await self.db_pool.runInteraction(
"count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn
)
async def get_e2e_bulk_unused_fallback_key_types(
self, user_ids: Collection[str]
) -> TransactionUnusedFallbackKeys:
"""
Finds, in bulk, the types of unused fallback keys for all the users specified.
Intended to be used by application services for populating unused fallback
keys in transactions.
Return structure is of the shape:
user_id -> device_id -> algorithms
Empty lists are created for devices if there are no unused fallback
keys. This matches the response structure of MSC3202.
"""
if len(user_ids) == 0:
return {}
def _get_bulk_e2e_unused_fallback_keys_txn(
txn: LoggingTransaction,
) -> TransactionUnusedFallbackKeys:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "devices.user_id", user_ids
)
# We can't use USING here because we require the `.used` condition
# to be part of the JOIN condition so that we generate empty lists
# when all keys are used (as opposed to just when there are no keys at all).
sql = f"""
SELECT devices.user_id, devices.device_id, algorithm
FROM devices
LEFT JOIN e2e_fallback_keys_json AS fallback_keys
ON devices.user_id = fallback_keys.user_id
AND devices.device_id = fallback_keys.device_id
AND NOT fallback_keys.used
WHERE
{user_in_where_clause}
"""
txn.execute(sql, user_parameters)
result: TransactionUnusedFallbackKeys = {}
for user_id, device_id, algorithm in txn:
# We deliberately construct empty dictionaries and lists for
# users and devices without any unused fallback keys.
# We *could* omit these empty dicts if there have been no
# changes since the last transaction, but we currently don't
# do any change tracking!
device_unused_keys = result.setdefault(user_id, {}).setdefault(
device_id, []
)
if algorithm is not None:
# algorithm will be None if this device has no keys.
device_unused_keys.append(algorithm)
return result
return await self.db_pool.runInteraction(
"_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn
)
async def set_e2e_fallback_keys( async def set_e2e_fallback_keys(
self, user_id: str, device_id: str, fallback_keys: JsonDict self, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None: ) -> None:

View File

@ -68,6 +68,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events, events=events,
ephemeral=[], ephemeral=[],
to_device_messages=[], # txn made and saved to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
) )
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed txn.complete.assert_called_once_with(self.store) # txn completed
@ -92,6 +94,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events, events=events,
ephemeral=[], ephemeral=[],
to_device_messages=[], # txn made and saved to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
) )
self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed self.assertEquals(0, txn.complete.call_count) # or completed
@ -114,7 +118,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with( self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[], to_device_messages=[] service=service,
events=events,
ephemeral=[],
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
) )
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@ -216,7 +225,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4) service = Mock(id=4)
event = Mock() event = Mock()
self.scheduler.enqueue_for_appservice(service, events=[event]) self.scheduler.enqueue_for_appservice(service, events=[event])
self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None)
def test_send_single_event_with_queue(self): def test_send_single_event_with_queue(self):
d = defer.Deferred() d = defer.Deferred()
@ -231,11 +240,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# (call enqueue_for_appservice multiple times deliberately) # (call enqueue_for_appservice multiple times deliberately)
self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event2])
self.scheduler.enqueue_for_appservice(service, events=[event3]) self.scheduler.enqueue_for_appservice(service, events=[event3])
self.txn_ctrl.send.assert_called_with(service, [event], [], []) self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None)
self.assertEquals(1, self.txn_ctrl.send.call_count) self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent # Resolve the send event: expect the queued events to be sent
d.callback(service) d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) self.txn_ctrl.send.assert_called_with(
service, [event2, event3], [], [], None, None
)
self.assertEquals(2, self.txn_ctrl.send.call_count) self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self): def test_multiple_service_queues(self):
@ -261,15 +272,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# send events for different ASes and make sure they are sent # send events for different ASes and make sure they are sent
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None)
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None)
# make sure callbacks for a service only send queued events for THAT # make sure callbacks for a service only send queued events for THAT
# service # service
srv_2_defer.callback(srv2) srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None)
self.assertEquals(3, self.txn_ctrl.send.call_count) self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self): def test_send_large_txns(self):
@ -288,13 +299,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
self.scheduler.enqueue_for_appservice(service, [event], []) self.scheduler.enqueue_for_appservice(service, [event], [])
# Expect the first event to be sent immediately. # Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) self.txn_ctrl.send.assert_called_with(
service, [event_list[0]], [], [], None, None
)
srv_1_defer.callback(service) srv_1_defer.callback(service)
# Then send the next 100 events # Then send the next 100 events
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) self.txn_ctrl.send.assert_called_with(
service, event_list[1:101], [], [], None, None
)
srv_2_defer.callback(service) srv_2_defer.callback(service)
# Then the final 99 events # Then the final 99 events
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) self.txn_ctrl.send.assert_called_with(
service, event_list[101:], [], [], None, None
)
self.assertEquals(3, self.txn_ctrl.send.call_count) self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self): def test_send_single_ephemeral_no_queue(self):
@ -302,14 +319,18 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
event_list = [Mock(name="event")] event_list = [Mock(name="event")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None
)
def test_send_multiple_ephemeral_no_queue(self): def test_send_multiple_ephemeral_no_queue(self):
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None
)
def test_send_single_ephemeral_with_queue(self): def test_send_single_ephemeral_with_queue(self):
d = defer.Deferred() d = defer.Deferred()
@ -324,13 +345,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Send more events: expect send() to NOT be called multiple times. # Send more events: expect send() to NOT be called multiple times.
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None)
self.assertEquals(1, self.txn_ctrl.send.call_count) self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send # Resolve txn_ctrl.send
d.callback(service) d.callback(service)
# Expect the queued events to be sent # Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with( self.txn_ctrl.send.assert_called_with(
service, [], event_list_2 + event_list_3, [] service, [], event_list_2 + event_list_3, [], None, None
) )
self.assertEquals(2, self.txn_ctrl.send.call_count) self.assertEquals(2, self.txn_ctrl.send.call_count)
@ -343,7 +364,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
event_list = first_chunk + second_chunk event_list = first_chunk + second_chunk
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) self.txn_ctrl.send.assert_called_once_with(
service, [], first_chunk, [], None, None
)
d.callback(service) d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None)
self.assertEquals(2, self.txn_ctrl.send.call_count) self.assertEquals(2, self.txn_ctrl.send.call_count)

View File

@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock from unittest.mock import Mock
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
import synapse.storage import synapse.storage
from synapse.appservice import ApplicationService from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, room, sendtodevice from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock from tests.test_utils import make_awaitable, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock from tests.utils import MockClock
@ -428,7 +436,14 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
# #
# The uninterested application service should not have been notified at all. # The uninterested application service should not have been notified at all.
self.send_mock.assert_called_once() self.send_mock.assert_called_once()
service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] (
service,
_events,
_ephemeral,
to_device_messages,
_otks,
_fbks,
) = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent # Assert that this was the same to-device message that local_user sent
self.assertEqual(service, interested_appservice) self.assertEqual(service, interested_appservice)
@ -540,7 +555,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
service_id_to_message_count: Dict[str, int] = {} service_id_to_message_count: Dict[str, int] = {}
for call in self.send_mock.call_args_list: for call in self.send_mock.call_args_list:
service, _events, _ephemeral, to_device_messages = call[0] service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0]
# Check that this was made to an interested service # Check that this was made to an interested service
self.assertIn(service, interested_appservices) self.assertIn(service, interested_appservices)
@ -582,3 +597,174 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self._services.append(appservice) self._services.append(appservice)
return appservice return appservice
class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Argument indices for pulling out arguments from a `send_mock`.
ARG_OTK_COUNTS = 4
ARG_FALLBACK_KEYS = 5
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
register.register_servlets,
room.register_servlets,
sendtodevice.register_servlets,
receipts.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track what's going out
self.send_mock = simple_async_mock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
# Define an application service for the tests
self._service_token = "VERYSECRET"
self._service = ApplicationService(
self._service_token,
"as1.invalid",
"as1",
"@as.sender:test",
namespaces={
"users": [
{"regex": "@_as_.*:test", "exclusive": True},
{"regex": "@as.sender:test", "exclusive": True},
]
},
msc3202_transaction_extensions=True,
)
self.hs.get_datastores().main.services_cache = [self._service]
# Register some appservice users
self._sender_user, self._sender_device = self.register_appservice_user(
"as.sender", self._service_token
)
self._namespaced_user, self._namespaced_device = self.register_appservice_user(
"_as_user1", self._service_token
)
# Register a real user as well.
self._real_user = self.register_user("real.user", "meow")
self._real_user_token = self.login("real.user", "meow")
async def _add_otks_for_device(
self, user_id: str, device_id: str, otk_count: int
) -> None:
"""
Add some dummy keys. It doesn't matter if they're not a real algorithm;
that should be opaque to the server anyway.
"""
await self.hs.get_datastores().main.add_e2e_one_time_keys(
user_id,
device_id,
self.clock.time_msec(),
[("algo", f"k{i}", "{}") for i in range(otk_count)],
)
async def _add_fallback_key_for_device(
self, user_id: str, device_id: str, used: bool
) -> None:
"""
Adds a fake fallback key to a device, optionally marking it as used
right away.
"""
store = self.hs.get_datastores().main
await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"})
if used is True:
# Mark the key as used
await store.db_pool.simple_update_one(
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": "algo",
"key_id": "fk",
},
updatevalues={"used": True},
desc="_get_fallback_key_set_used",
)
def _set_up_devices_and_a_room(self) -> str:
"""
Helper to set up devices for all the users
and a room for the users to talk in.
"""
async def preparation():
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
await self._add_fallback_key_for_device(
self._sender_user, self._sender_device, used=True
)
await self._add_otks_for_device(
self._namespaced_user, self._namespaced_device, 36
)
await self._add_fallback_key_for_device(
self._namespaced_user, self._namespaced_device, used=False
)
# Register a device for the real user, too, so that we can later ensure
# that we don't leak information to the AS about the non-AS user.
await self.hs.get_datastores().main.store_device(
self._real_user, "REALDEV", "UltraMatrix 3000"
)
await self._add_otks_for_device(self._real_user, "REALDEV", 50)
self.get_success(preparation())
room_id = self.helper.create_room_as(
self._real_user, is_public=True, tok=self._real_user_token
)
self.helper.join(
room_id,
self._namespaced_user,
tok=self._service_token,
appservice_user_id=self._namespaced_user,
)
# Check it was called for sanity. (This was to send the join event to the AS.)
self.send_mock.assert_called()
self.send_mock.reset_mock()
return room_id
@override_config(
{"experimental_features": {"msc3202_transaction_extensions": True}}
)
def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus(
self,
) -> None:
"""
Tests that:
- the AS receives one-time key counts and unused fallback keys for:
- the specified sender; and
- any user who is in receipt of the PDUs
"""
room_id = self._set_up_devices_and_a_room()
# Send a message into the AS's room
self.helper.send(room_id, "woof woof", tok=self._real_user_token)
# Capture what was sent as an AS transaction.
self.send_mock.assert_called()
last_args, _last_kwargs = self.send_mock.call_args
otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS]
unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[
self.ARG_FALLBACK_KEYS
]
self.assertEqual(
otks,
{
"@as.sender:test": {self._sender_device: {"algo": 42}},
"@_as_user1:test": {self._namespaced_device: {"algo": 36}},
},
)
self.assertEqual(
unused_fallbacks,
{
"@as.sender:test": {self._sender_device: []},
"@_as_user1:test": {self._namespaced_device: ["algo"]},
},
)

View File

@ -267,7 +267,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success( txn = self.get_success(
defer.ensureDeferred( defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [], []) self.store.create_appservice_txn(service, events, [], [], {}, {})
) )
) )
self.assertEquals(txn.id, 1) self.assertEquals(txn.id, 1)
@ -283,7 +283,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events)) self.get_success(self._insert_txn(service.id, 9645, events))
txn = self.get_success( txn = self.get_success(
self.store.create_appservice_txn(service, events, [], []) self.store.create_appservice_txn(service, events, [], [], {}, {})
) )
self.assertEquals(txn.id, 9646) self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
@ -296,7 +296,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643)) self.get_success(self._set_last_txn(service.id, 9643))
txn = self.get_success( txn = self.get_success(
self.store.create_appservice_txn(service, events, [], []) self.store.create_appservice_txn(service, events, [], [], {}, {})
) )
self.assertEquals(txn.id, 9644) self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)
@ -320,7 +320,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
txn = self.get_success( txn = self.get_success(
self.store.create_appservice_txn(service, events, [], []) self.store.create_appservice_txn(service, events, [], [], {}, {})
) )
self.assertEquals(txn.id, 9644) self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events) self.assertEquals(txn.events, events)