Add experimental support for MSC3391: deleting account data (#14714)

This commit is contained in:
Andrew Morgan 2023-01-01 03:40:46 +00:00 committed by GitHub
parent 044fa1a1de
commit c4456114e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 547 additions and 31 deletions

View File

@ -0,0 +1 @@
Add experimental support for [MSC3391](https://github.com/matrix-org/matrix-spec-proposals/pull/3391) (removing account data).

View File

@ -102,6 +102,8 @@ experimental_features:
{% endif %} {% endif %}
# Filtering /messages by relation type. # Filtering /messages by relation type.
msc3874_enabled: true msc3874_enabled: true
# Enable removing account data support
msc3391_enabled: true
server_notices: server_notices:
system_mxid_localpart: _server system_mxid_localpart: _server

View File

@ -190,7 +190,7 @@ fi
extra_test_args=() extra_test_args=()
test_tags="synapse_blacklist,msc3787,msc3874" test_tags="synapse_blacklist,msc3787,msc3874,msc3391"
# All environment variables starting with PASS_ will be shared. # All environment variables starting with PASS_ will be shared.
# (The prefix is stripped off before reaching the container.) # (The prefix is stripped off before reaching the container.)

View File

@ -136,3 +136,6 @@ class ExperimentalConfig(Config):
# Enable room version (and thus applicable push rules from MSC3931/3932) # Enable room version (and thus applicable push rules from MSC3931/3932)
version_id = RoomVersions.MSC1767v10.identifier version_id = RoomVersions.MSC1767v10.identifier
KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10 KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10
# MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False)

View File

@ -17,10 +17,12 @@ import random
from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
from synapse.replication.http.account_data import ( from synapse.replication.http.account_data import (
ReplicationAddRoomAccountDataRestServlet,
ReplicationAddTagRestServlet, ReplicationAddTagRestServlet,
ReplicationAddUserAccountDataRestServlet,
ReplicationRemoveRoomAccountDataRestServlet,
ReplicationRemoveTagRestServlet, ReplicationRemoveTagRestServlet,
ReplicationRoomAccountDataRestServlet, ReplicationRemoveUserAccountDataRestServlet,
ReplicationUserAccountDataRestServlet,
) )
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import JsonDict, StreamKeyType, UserID from synapse.types import JsonDict, StreamKeyType, UserID
@ -41,8 +43,18 @@ class AccountDataHandler:
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier() self._notifier = hs.get_notifier()
self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) self._add_user_data_client = (
self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) ReplicationAddUserAccountDataRestServlet.make_client(hs)
)
self._remove_user_data_client = (
ReplicationRemoveUserAccountDataRestServlet.make_client(hs)
)
self._add_room_data_client = (
ReplicationAddRoomAccountDataRestServlet.make_client(hs)
)
self._remove_room_data_client = (
ReplicationRemoveRoomAccountDataRestServlet.make_client(hs)
)
self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data self._account_data_writers = hs.config.worker.writers.account_data
@ -112,7 +124,7 @@ class AccountDataHandler:
return max_stream_id return max_stream_id
else: else:
response = await self._room_data_client( response = await self._add_room_data_client(
instance_name=random.choice(self._account_data_writers), instance_name=random.choice(self._account_data_writers),
user_id=user_id, user_id=user_id,
room_id=room_id, room_id=room_id,
@ -121,15 +133,59 @@ class AccountDataHandler:
) )
return response["max_stream_id"] return response["max_stream_id"]
async def remove_account_data_for_room(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[int]:
"""
Deletes the room account data for the given user and account data type.
"Deleting" account data merely means setting the content of the account data
to an empty JSON object: {}.
Args:
user_id: The user ID to remove room account data for.
room_id: The room ID to target.
account_data_type: The account data type to remove.
Returns:
The maximum stream ID, or None if the room account data item did not exist.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.remove_account_data_for_room(
user_id, room_id, account_data_type
)
if max_stream_id is None:
# The referenced account data did not exist, so no delete occurred.
return None
self._notifier.on_new_event(
StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
# Notify Synapse modules that the content of the type has changed to an
# empty dictionary.
await self._notify_modules(user_id, room_id, account_data_type, {})
return max_stream_id
else:
response = await self._remove_room_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
account_data_type=account_data_type,
content={},
)
return response["max_stream_id"]
async def add_account_data_for_user( async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict self, user_id: str, account_data_type: str, content: JsonDict
) -> int: ) -> int:
"""Add some global account_data for a user. """Add some global account_data for a user.
Args: Args:
user_id: The user to add a tag for. user_id: The user to add some account data for.
account_data_type: The type of account_data to add. account_data_type: The type of account_data to add.
content: A json object to associate with the tag. content: The content json dictionary.
Returns: Returns:
The maximum stream ID. The maximum stream ID.
@ -148,7 +204,7 @@ class AccountDataHandler:
return max_stream_id return max_stream_id
else: else:
response = await self._user_data_client( response = await self._add_user_data_client(
instance_name=random.choice(self._account_data_writers), instance_name=random.choice(self._account_data_writers),
user_id=user_id, user_id=user_id,
account_data_type=account_data_type, account_data_type=account_data_type,
@ -156,6 +212,45 @@ class AccountDataHandler:
) )
return response["max_stream_id"] return response["max_stream_id"]
async def remove_account_data_for_user(
self, user_id: str, account_data_type: str
) -> Optional[int]:
"""Removes a piece of global account_data for a user.
Args:
user_id: The user to remove account data for.
account_data_type: The type of account_data to remove.
Returns:
The maximum stream ID, or None if the room account data item did not exist.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.remove_account_data_for_user(
user_id, account_data_type
)
if max_stream_id is None:
# The referenced account data did not exist, so no delete occurred.
return None
self._notifier.on_new_event(
StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
)
# Notify Synapse modules that the content of the type has changed to an
# empty dictionary.
await self._notify_modules(user_id, None, account_data_type, {})
return max_stream_id
else:
response = await self._remove_user_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
account_data_type=account_data_type,
content={},
)
return response["max_stream_id"]
async def add_tag_to_room( async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int: ) -> int:

View File

@ -28,7 +28,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint):
"""Add user account data on the appropriate account data worker. """Add user account data on the appropriate account data worker.
Request format: Request format:
@ -49,7 +49,6 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload( # type: ignore[override] async def _serialize_payload( # type: ignore[override]
@ -73,7 +72,45 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
return 200, {"max_stream_id": max_stream_id} return 200, {"max_stream_id": max_stream_id}
class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint):
"""Remove user account data on the appropriate account data worker.
Request format:
POST /_synapse/replication/remove_user_account_data/:user_id/:type
{
"content": { ... },
}
"""
NAME = "remove_user_account_data"
PATH_ARGS = ("user_id", "account_data_type")
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
@staticmethod
async def _serialize_payload( # type: ignore[override]
user_id: str, account_data_type: str
) -> JsonDict:
return {}
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_user(
user_id, account_data_type
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint):
"""Add room account data on the appropriate account data worker. """Add room account data on the appropriate account data worker.
Request format: Request format:
@ -94,7 +131,6 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload( # type: ignore[override] async def _serialize_payload( # type: ignore[override]
@ -118,6 +154,44 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
return 200, {"max_stream_id": max_stream_id} return 200, {"max_stream_id": max_stream_id}
class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint):
"""Remove room account data on the appropriate account data worker.
Request format:
POST /_synapse/replication/remove_room_account_data/:user_id/:room_id/:account_data_type
{
"content": { ... },
}
"""
NAME = "remove_room_account_data"
PATH_ARGS = ("user_id", "room_id", "account_data_type")
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
@staticmethod
async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
return {}
async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_account_data_for_room(
user_id, room_id, account_data_type
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationAddTagRestServlet(ReplicationEndpoint): class ReplicationAddTagRestServlet(ReplicationEndpoint):
"""Add tag on the appropriate account data worker. """Add tag on the appropriate account data worker.
@ -139,7 +213,6 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload( # type: ignore[override] async def _serialize_payload( # type: ignore[override]
@ -186,7 +259,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
super().__init__(hs) super().__init__(hs)
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]
@ -206,7 +278,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserAccountDataRestServlet(hs).register(http_server) ReplicationAddUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server) ReplicationAddRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server)
ReplicationRemoveTagRestServlet(hs).register(http_server) ReplicationRemoveTagRestServlet(hs).register(http_server)
if hs.config.experimental.msc3391_enabled:
ReplicationRemoveUserAccountDataRestServlet(hs).register(http_server)
ReplicationRemoveRoomAccountDataRestServlet(hs).register(http_server)

View File

@ -41,6 +41,7 @@ class AccountDataServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self._hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
@ -54,6 +55,16 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# If experimental support for MSC3391 is enabled, then providing an empty dict
# as the value for an account data type should be functionally equivalent to
# calling the DELETE method on the same type.
if self._hs.config.experimental.msc3391_enabled:
if body == {}:
await self.handler.remove_account_data_for_user(
user_id, account_data_type
)
return 200, {}
await self.handler.add_account_data_for_user(user_id, account_data_type, body) await self.handler.add_account_data_for_user(user_id, account_data_type, body)
return 200, {} return 200, {}
@ -72,9 +83,48 @@ class AccountDataServlet(RestServlet):
if event is None: if event is None:
raise NotFoundError("Account data not found") raise NotFoundError("Account data not found")
# If experimental support for MSC3391 is enabled, then this endpoint should
# return a 404 if the content for an account data type is an empty dict.
if self._hs.config.experimental.msc3391_enabled and event == {}:
raise NotFoundError("Account data not found")
return 200, event return 200, event
class UnstableAccountDataServlet(RestServlet):
"""
Contains an unstable endpoint for removing user account data, as specified by
MSC3391. If that MSC is accepted, this code should have unstable prefixes removed
and become incorporated into AccountDataServlet above.
"""
PATTERNS = client_patterns(
"/org.matrix.msc3391/user/(?P<user_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)",
unstable=True,
releases=(),
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
async def on_DELETE(
self,
request: SynapseRequest,
user_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot delete account data for other users.")
await self.handler.remove_account_data_for_user(user_id, account_data_type)
return 200, {}
class RoomAccountDataServlet(RestServlet): class RoomAccountDataServlet(RestServlet):
""" """
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
@ -89,6 +139,7 @@ class RoomAccountDataServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self._hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
@ -121,6 +172,16 @@ class RoomAccountDataServlet(RestServlet):
Codes.BAD_JSON, Codes.BAD_JSON,
) )
# If experimental support for MSC3391 is enabled, then providing an empty dict
# as the value for an account data type should be functionally equivalent to
# calling the DELETE method on the same type.
if self._hs.config.experimental.msc3391_enabled:
if body == {}:
await self.handler.remove_account_data_for_room(
user_id, room_id, account_data_type
)
return 200, {}
await self.handler.add_account_data_to_room( await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body user_id, room_id, account_data_type, body
) )
@ -152,9 +213,63 @@ class RoomAccountDataServlet(RestServlet):
if event is None: if event is None:
raise NotFoundError("Room account data not found") raise NotFoundError("Room account data not found")
# If experimental support for MSC3391 is enabled, then this endpoint should
# return a 404 if the content for an account data type is an empty dict.
if self._hs.config.experimental.msc3391_enabled and event == {}:
raise NotFoundError("Room account data not found")
return 200, event return 200, event
class UnstableRoomAccountDataServlet(RestServlet):
"""
Contains an unstable endpoint for removing room account data, as specified by
MSC3391. If that MSC is accepted, this code should have unstable prefixes removed
and become incorporated into RoomAccountDataServlet above.
"""
PATTERNS = client_patterns(
"/org.matrix.msc3391/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)",
unstable=True,
releases=(),
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
async def on_DELETE(
self,
request: SynapseRequest,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot delete account data for other users.")
if not RoomID.is_valid(room_id):
raise SynapseError(
400,
f"{room_id} is not a valid room ID",
Codes.INVALID_PARAM,
)
await self.handler.remove_account_data_for_room(
user_id, room_id, account_data_type
)
return 200, {}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountDataServlet(hs).register(http_server) AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server)
if hs.config.experimental.msc3391_enabled:
UnstableAccountDataServlet(hs).register(http_server)
UnstableRoomAccountDataServlet(hs).register(http_server)

View File

@ -1762,7 +1762,8 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
Returns: Returns:
A list of dictionaries. A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
""" """
return await self.runInteraction( return await self.runInteraction(
desc, desc,
@ -1791,6 +1792,10 @@ class DatabasePool:
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
retcols: the names of the columns to return retcols: the names of the columns to return
Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
""" """
if keyvalues: if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % ( sql = "SELECT %s FROM %s WHERE %s" % (
@ -1898,6 +1903,19 @@ class DatabasePool:
updatevalues: Dict[str, Any], updatevalues: Dict[str, Any],
desc: str, desc: str,
) -> int: ) -> int:
"""
Update rows in the given database table.
If the given keyvalues don't match anything, nothing will be updated.
Args:
table: The database table to update.
keyvalues: A mapping of column name to value to match rows on.
updatevalues: A mapping of column name to value to replace in any matched rows.
desc: description of the transaction, for logging and metrics.
Returns:
The number of rows that were updated. Will be 0 if no matching rows were found.
"""
return await self.runInteraction( return await self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues desc, self.simple_update_txn, table, keyvalues, updatevalues
) )
@ -1909,6 +1927,19 @@ class DatabasePool:
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any], updatevalues: Dict[str, Any],
) -> int: ) -> int:
"""
Update rows in the given database table.
If the given keyvalues don't match anything, nothing will be updated.
Args:
txn: The database transaction object.
table: The database table to update.
keyvalues: A mapping of column name to value to match rows on.
updatevalues: A mapping of column name to value to replace in any matched rows.
Returns:
The number of rows that were updated. Will be 0 if no matching rows were found.
"""
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else: else:

View File

@ -123,7 +123,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
async def get_account_data_for_user( async def get_account_data_for_user(
self, user_id: str self, user_id: str
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user. """
Get all the client account_data for a user.
If experimental MSC3391 support is enabled, any entries with an empty
content body are excluded; as this means they have been deleted.
Args: Args:
user_id: The user to get the account_data for. user_id: The user to get the account_data for.
@ -135,27 +139,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def get_account_data_for_user_txn( def get_account_data_for_user_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
rows = self.db_pool.simple_select_list_txn( # The 'content != '{}' condition below prevents us from using
txn, # `simple_select_list_txn` here, as it doesn't support conditions
"account_data", # other than 'equals'.
{"user_id": user_id}, sql = """
["account_data_type", "content"], SELECT account_data_type, content FROM account_data
) WHERE user_id = ?
"""
# If experimental MSC3391 support is enabled, then account data entries
# with an empty content are considered "deleted". So skip adding them to
# the results.
if self.hs.config.experimental.msc3391_enabled:
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
global_account_data = { global_account_data = {
row["account_data_type"]: db_to_json(row["content"]) for row in rows row["account_data_type"]: db_to_json(row["content"]) for row in rows
} }
rows = self.db_pool.simple_select_list_txn( # The 'content != '{}' condition below prevents us from using
txn, # `simple_select_list_txn` here, as it doesn't support conditions
"room_account_data", # other than 'equals'.
{"user_id": user_id}, sql = """
["room_id", "account_data_type", "content"], SELECT room_id, account_data_type, content FROM room_account_data
) WHERE user_id = ?
"""
# If experimental MSC3391 support is enabled, then account data entries
# with an empty content are considered "deleted". So skip adding them to
# the results.
if self.hs.config.experimental.msc3391_enabled:
sql += " AND content != '{}'"
txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {} by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows: for row in rows:
room_data = by_room.setdefault(row["room_id"], {}) room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = db_to_json(row["content"]) room_data[row["account_data_type"]] = db_to_json(row["content"])
return global_account_data, by_room return global_account_data, by_room
@ -469,6 +494,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
async def remove_account_data_for_room(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[int]:
"""Delete the room account data for the user of a given type.
Args:
user_id: The user to remove account_data for.
room_id: The room ID to scope the request to.
account_data_type: The account data type to delete.
Returns:
The maximum stream position, or None if there was no matching room account
data to delete.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_room_txn(
txn: LoggingTransaction, next_id: int
) -> bool:
"""
Args:
txn: The transaction object.
next_id: The stream_id to update any existing rows to.
Returns:
True if an entry in room_account_data had its content set to '{}',
otherwise False. This informs callers of whether there actually was an
existing room account data entry to delete, or if the call was a no-op.
"""
# We can't use `simple_update` as it doesn't have the ability to specify
# where clauses other than '=', which we need for `content != '{}'` below.
sql = """
UPDATE room_account_data
SET stream_id = ?, content = '{}'
WHERE user_id = ?
AND room_id = ?
AND account_data_type = ?
AND content != '{}'
"""
txn.execute(
sql,
(next_id, user_id, room_id, account_data_type),
)
# Return true if any rows were updated.
return txn.rowcount != 0
async with self._account_data_id_gen.get_next() as next_id:
row_updated = await self.db_pool.runInteraction(
"remove_account_data_for_room",
_remove_account_data_for_room_txn,
next_id,
)
if not row_updated:
return None
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type), {}
)
return self._account_data_id_gen.get_current_token()
async def add_account_data_for_user( async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict self, user_id: str, account_data_type: str, content: JsonDict
) -> int: ) -> int:
@ -569,6 +660,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def remove_account_data_for_user(
self,
user_id: str,
account_data_type: str,
) -> Optional[int]:
"""
Delete a single piece of user account data by type.
A "delete" is performed by updating a potentially existing row in the
"account_data" database table for (user_id, account_data_type) and
setting its content to "{}".
Args:
user_id: The user ID to modify the account data of.
account_data_type: The type to remove.
Returns:
The maximum stream position, or None if there was no matching account data
to delete.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
def _remove_account_data_for_user_txn(
txn: LoggingTransaction, next_id: int
) -> bool:
"""
Args:
txn: The transaction object.
next_id: The stream_id to update any existing rows to.
Returns:
True if an entry in account_data had its content set to '{}', otherwise
False. This informs callers of whether there actually was an existing
account data entry to delete, or if the call was a no-op.
"""
# We can't use `simple_update` as it doesn't have the ability to specify
# where clauses other than '=', which we need for `content != '{}'` below.
sql = """
UPDATE account_data
SET stream_id = ?, content = '{}'
WHERE user_id = ?
AND account_data_type = ?
AND content != '{}'
"""
txn.execute(sql, (next_id, user_id, account_data_type))
if txn.rowcount == 0:
# We didn't update any rows. This means that there was no matching room
# account data entry to delete in the first place.
return False
# Ignored users get denormalized into a separate table as an optimisation.
if account_data_type == AccountDataTypes.IGNORED_USER_LIST:
# If this method was called with the ignored users account data type, we
# simply delete all ignored users.
# First pull all the users that this user ignores.
previously_ignored_users = set(
self.db_pool.simple_select_onecol_txn(
txn,
table="ignored_users",
keyvalues={"ignorer_user_id": user_id},
retcol="ignored_user_id",
)
)
# Then delete them from the database.
self.db_pool.simple_delete_txn(
txn,
table="ignored_users",
keyvalues={"ignorer_user_id": user_id},
)
# Invalidate the cache for ignored users which were removed.
for ignored_user_id in previously_ignored_users:
self._invalidate_cache_and_stream(
txn, self.ignored_by, (ignored_user_id,)
)
# Invalidate for this user the cache tracking ignored users.
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
return True
async with self._account_data_id_gen.get_next() as next_id:
row_updated = await self.db_pool.runInteraction(
"remove_account_data_for_user",
_remove_account_data_for_user_txn,
next_id,
)
if not row_updated:
return None
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.prefill(
(user_id, account_data_type), {}
)
return self._account_data_id_gen.get_current_token()
async def purge_account_data_for_user(self, user_id: str) -> None: async def purge_account_data_for_user(self, user_id: str) -> None:
""" """
Removes ALL the account data for a user. Removes ALL the account data for a user.