Implement MSC3816, consider the root event for thread participation. (#12766)
As opposed to only considering a user to have "participated" if they replied to the thread.
This commit is contained in:
parent
fcd8703508
commit
1acc897c31
|
@ -0,0 +1 @@
|
||||||
|
Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it.
|
|
@ -12,16 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
|
||||||
TYPE_CHECKING,
|
|
||||||
Collection,
|
|
||||||
Dict,
|
|
||||||
FrozenSet,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
)
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -256,13 +247,19 @@ class RelationsHandler:
|
||||||
|
|
||||||
return filtered_results
|
return filtered_results
|
||||||
|
|
||||||
async def get_threads_for_events(
|
async def _get_threads_for_events(
|
||||||
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
|
self,
|
||||||
|
events_by_id: Dict[str, EventBase],
|
||||||
|
relations_by_id: Dict[str, str],
|
||||||
|
user_id: str,
|
||||||
|
ignored_users: FrozenSet[str],
|
||||||
) -> Dict[str, _ThreadAggregation]:
|
) -> Dict[str, _ThreadAggregation]:
|
||||||
"""Get the bundled aggregations for threads for the requested events.
|
"""Get the bundled aggregations for threads for the requested events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids: Events to get aggregations for threads.
|
events_by_id: A map of event_id to events to get aggregations for threads.
|
||||||
|
relations_by_id: A map of event_id to the relation type, if one exists
|
||||||
|
for that event.
|
||||||
user_id: The user requesting the bundled aggregations.
|
user_id: The user requesting the bundled aggregations.
|
||||||
ignored_users: The users ignored by the requesting user.
|
ignored_users: The users ignored by the requesting user.
|
||||||
|
|
||||||
|
@ -273,16 +270,34 @@ class RelationsHandler:
|
||||||
"""
|
"""
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
|
# It is not valid to start a thread on an event which itself relates to another event.
|
||||||
|
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
|
||||||
|
|
||||||
# Fetch thread summaries.
|
# Fetch thread summaries.
|
||||||
summaries = await self._main_store.get_thread_summaries(event_ids)
|
summaries = await self._main_store.get_thread_summaries(event_ids)
|
||||||
|
|
||||||
# Only fetch participated for a limited selection based on what had
|
# Limit fetching whether the requester has participated in a thread to
|
||||||
# summaries.
|
# events which are thread roots.
|
||||||
thread_event_ids = [
|
thread_event_ids = [
|
||||||
event_id for event_id, summary in summaries.items() if summary
|
event_id for event_id, summary in summaries.items() if summary
|
||||||
]
|
]
|
||||||
participated = await self._main_store.get_threads_participated(
|
|
||||||
thread_event_ids, user_id
|
# Pre-seed thread participation with whether the requester sent the event.
|
||||||
|
participated = {
|
||||||
|
event_id: events_by_id[event_id].sender == user_id
|
||||||
|
for event_id in thread_event_ids
|
||||||
|
}
|
||||||
|
# For events the requester did not send, check the database for whether
|
||||||
|
# the requester sent a threaded reply.
|
||||||
|
participated.update(
|
||||||
|
await self._main_store.get_threads_participated(
|
||||||
|
[
|
||||||
|
event_id
|
||||||
|
for event_id in thread_event_ids
|
||||||
|
if not participated[event_id]
|
||||||
|
],
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Then subtract off the results for any ignored users.
|
# Then subtract off the results for any ignored users.
|
||||||
|
@ -343,7 +358,8 @@ class RelationsHandler:
|
||||||
count=thread_count,
|
count=thread_count,
|
||||||
# If there's a thread summary it must also exist in the
|
# If there's a thread summary it must also exist in the
|
||||||
# participated dictionary.
|
# participated dictionary.
|
||||||
current_user_participated=participated[event_id],
|
current_user_participated=events_by_id[event_id].sender == user_id
|
||||||
|
or participated[event_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
@ -401,9 +417,9 @@ class RelationsHandler:
|
||||||
# events to be fetched. Thus, we check those first!
|
# events to be fetched. Thus, we check those first!
|
||||||
|
|
||||||
# Fetch thread summaries (but only for the directly requested events).
|
# Fetch thread summaries (but only for the directly requested events).
|
||||||
threads = await self.get_threads_for_events(
|
threads = await self._get_threads_for_events(
|
||||||
# It is not valid to start a thread on an event which itself relates to another event.
|
events_by_id,
|
||||||
[eid for eid in events_by_id.keys() if eid not in relations_by_id],
|
relations_by_id,
|
||||||
user_id,
|
user_id,
|
||||||
ignored_users,
|
ignored_users,
|
||||||
)
|
)
|
||||||
|
|
|
@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
relation_type: str,
|
relation_type: str,
|
||||||
assertion_callable: Callable[[JsonDict], None],
|
assertion_callable: Callable[[JsonDict], None],
|
||||||
expected_db_txn_for_event: int,
|
expected_db_txn_for_event: int,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Makes requests to various endpoints which should include bundled aggregations
|
Makes requests to various endpoints which should include bundled aggregations
|
||||||
|
@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
for relation-specific assertions.
|
for relation-specific assertions.
|
||||||
expected_db_txn_for_event: The number of database transactions which
|
expected_db_txn_for_event: The number of database transactions which
|
||||||
are expected for a call to /event/.
|
are expected for a call to /event/.
|
||||||
|
access_token: The access token to user, defaults to self.user_token.
|
||||||
"""
|
"""
|
||||||
|
access_token = access_token or self.user_token
|
||||||
|
|
||||||
def assert_bundle(event_json: JsonDict) -> None:
|
def assert_bundle(event_json: JsonDict) -> None:
|
||||||
"""Assert the expected values of the bundled aggregations."""
|
"""Assert the expected values of the bundled aggregations."""
|
||||||
|
@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/rooms/{self.room}/event/{self.parent_id}",
|
f"/rooms/{self.room}/event/{self.parent_id}",
|
||||||
access_token=self.user_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.json_body)
|
self.assertEqual(200, channel.code, channel.json_body)
|
||||||
assert_bundle(channel.json_body)
|
assert_bundle(channel.json_body)
|
||||||
|
@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/rooms/{self.room}/messages?dir=b",
|
f"/rooms/{self.room}/messages?dir=b",
|
||||||
access_token=self.user_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.json_body)
|
self.assertEqual(200, channel.code, channel.json_body)
|
||||||
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
|
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
|
||||||
|
@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/rooms/{self.room}/context/{self.parent_id}",
|
f"/rooms/{self.room}/context/{self.parent_id}",
|
||||||
access_token=self.user_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.json_body)
|
self.assertEqual(200, channel.code, channel.json_body)
|
||||||
assert_bundle(channel.json_body["event"])
|
assert_bundle(channel.json_body["event"])
|
||||||
|
@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
# Request sync.
|
# Request sync.
|
||||||
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
|
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"GET", f"/sync?filter={filter}", access_token=self.user_token
|
"GET", f"/sync?filter={filter}", access_token=access_token
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.json_body)
|
self.assertEqual(200, channel.code, channel.json_body)
|
||||||
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
|
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
|
||||||
|
@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
"/search",
|
"/search",
|
||||||
# Search term matches the parent message.
|
# Search term matches the parent message.
|
||||||
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
|
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
|
||||||
access_token=self.user_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
self.assertEqual(200, channel.code, channel.json_body)
|
self.assertEqual(200, channel.code, channel.json_body)
|
||||||
chunk = [
|
chunk = [
|
||||||
|
@ -1037,13 +1040,24 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
"""
|
"""
|
||||||
Test that threads get correctly bundled.
|
Test that threads get correctly bundled.
|
||||||
"""
|
"""
|
||||||
self._send_relation(RelationTypes.THREAD, "m.room.test")
|
# The root message is from "user", send replies as "user2".
|
||||||
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
|
self._send_relation(
|
||||||
|
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
|
||||||
|
)
|
||||||
|
channel = self._send_relation(
|
||||||
|
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
|
||||||
|
)
|
||||||
thread_2 = channel.json_body["event_id"]
|
thread_2 = channel.json_body["event_id"]
|
||||||
|
|
||||||
|
# This needs two assertion functions which are identical except for whether
|
||||||
|
# the current_user_participated flag is True, create a factory for the
|
||||||
|
# two versions.
|
||||||
|
def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
|
||||||
def assert_thread(bundled_aggregations: JsonDict) -> None:
|
def assert_thread(bundled_aggregations: JsonDict) -> None:
|
||||||
self.assertEqual(2, bundled_aggregations.get("count"))
|
self.assertEqual(2, bundled_aggregations.get("count"))
|
||||||
self.assertTrue(bundled_aggregations.get("current_user_participated"))
|
self.assertEqual(
|
||||||
|
participated, bundled_aggregations.get("current_user_participated")
|
||||||
|
)
|
||||||
# The latest thread event has some fields that don't matter.
|
# The latest thread event has some fields that don't matter.
|
||||||
self.assert_dict(
|
self.assert_dict(
|
||||||
{
|
{
|
||||||
|
@ -1054,13 +1068,32 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"event_id": thread_2,
|
"event_id": thread_2,
|
||||||
"sender": self.user_id,
|
"sender": self.user2_id,
|
||||||
"type": "m.room.test",
|
"type": "m.room.test",
|
||||||
},
|
},
|
||||||
bundled_aggregations.get("latest_event"),
|
bundled_aggregations.get("latest_event"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
|
return assert_thread
|
||||||
|
|
||||||
|
# The "user" sent the root event and is making queries for the bundled
|
||||||
|
# aggregations: they have participated.
|
||||||
|
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
|
||||||
|
# The "user2" sent replies in the thread and is making queries for the
|
||||||
|
# bundled aggregations: they have participated.
|
||||||
|
#
|
||||||
|
# Note that this re-uses some cached values, so the total number of
|
||||||
|
# queries is much smaller.
|
||||||
|
self._test_bundled_aggregations(
|
||||||
|
RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# A user with no interactions with the thread: they have not participated.
|
||||||
|
user3_id, user3_token = self._create_user("charlie")
|
||||||
|
self.helper.join(self.room, user=user3_id, tok=user3_token)
|
||||||
|
self._test_bundled_aggregations(
|
||||||
|
RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
|
||||||
|
)
|
||||||
|
|
||||||
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
|
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||||
bundled_aggregations["latest_event"].get("unsigned"),
|
bundled_aggregations["latest_event"].get("unsigned"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
|
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
|
||||||
|
|
||||||
def test_nested_thread(self) -> None:
|
def test_nested_thread(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue