Clean up types for PaginationConfig (#8250)

This removes `SourcePaginationConfig` and `get_pagination_rows`. The reasoning behind this is that these generic classes/functions erased the types of the IDs it used (i.e. instead of passing around `StreamToken` it'd pass in e.g. `token.room_key`, which don't have uniform types).
This commit is contained in:
Erik Johnston 2020-09-08 15:00:17 +01:00 committed by GitHub
parent 703e2b8a96
commit 0f545e6b96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 52 additions and 86 deletions

1
changelog.d/8250.misc Normal file
View File

@ -0,0 +1 @@
Clean up type hints for `PaginationConfig`.

View File

@ -116,14 +116,13 @@ class InitialSyncHandler(BaseHandler):
now_token = self.hs.get_event_sources().get_current_token() now_token = self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"] presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token) presence, _ = await presence_stream.get_new_events(
presence, _ = await presence_stream.get_pagination_rows( user, from_key=None, include_offline=False
user, pagination_config.get_source_config("presence"), None
) )
receipt_stream = self.hs.get_event_sources().sources["receipt"] joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
receipt, _ = await receipt_stream.get_pagination_rows( receipt = await self.store.get_linearized_receipts_for_rooms(
user, pagination_config.get_source_config("receipt"), None joined_rooms, to_key=int(now_token.receipt_key),
) )
tags_by_room = await self.store.get_tags_for_user(user_id) tags_by_room = await self.store.get_tags_for_user(user_id)

View File

@ -335,20 +335,16 @@ class PaginationHandler:
user_id = requester.user.to_string() user_id = requester.user.to_string()
if pagin_config.from_token: if pagin_config.from_token:
room_token = pagin_config.from_token.room_key from_token = pagin_config.from_token
else: else:
pagin_config.from_token = ( from_token = self.hs.get_event_sources().get_current_token_for_pagination()
self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token) if pagin_config.limit is None:
# This shouldn't happen as we've set a default limit before this
# gets called.
raise Exception("limit not set")
pagin_config.from_token = pagin_config.from_token.copy_and_replace( room_token = RoomStreamToken.parse(from_token.room_key)
"room_key", str(room_token)
)
source_config = pagin_config.get_source_config("room")
with await self.pagination_lock.read(room_id): with await self.pagination_lock.read(room_id):
( (
@ -358,7 +354,7 @@ class PaginationHandler:
room_id, user_id, allow_departed_users=True room_id, user_id, allow_departed_users=True
) )
if source_config.direction == "b": if pagin_config.direction == "b":
# if we're going backwards, we might need to backfill. This # if we're going backwards, we might need to backfill. This
# requires that we have a topo token. # requires that we have a topo token.
if room_token.topological: if room_token.topological:
@ -381,22 +377,28 @@ class PaginationHandler:
member_event_id member_event_id
) )
if RoomStreamToken.parse(leave_token).topological < max_topo: if RoomStreamToken.parse(leave_token).topological < max_topo:
source_config.from_key = str(leave_token) from_token = from_token.copy_and_replace(
"room_key", leave_token
)
await self.hs.get_handlers().federation_handler.maybe_backfill( await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo room_id, max_topo
) )
to_room_key = None
if pagin_config.to_token:
to_room_key = pagin_config.to_token.room_key
events, next_key = await self.store.paginate_room_events( events, next_key = await self.store.paginate_room_events(
room_id=room_id, room_id=room_id,
from_key=source_config.from_key, from_key=from_token.room_key,
to_key=source_config.to_key, to_key=to_room_key,
direction=source_config.direction, direction=pagin_config.direction,
limit=source_config.limit, limit=pagin_config.limit,
event_filter=event_filter, event_filter=event_filter,
) )
next_token = pagin_config.from_token.copy_and_replace("room_key", next_key) next_token = from_token.copy_and_replace("room_key", next_key)
if events: if events:
if event_filter: if event_filter:
@ -409,7 +411,7 @@ class PaginationHandler:
if not events: if not events:
return { return {
"chunk": [], "chunk": [],
"start": pagin_config.from_token.to_string(), "start": from_token.to_string(),
"end": next_token.to_string(), "end": next_token.to_string(),
} }
@ -438,7 +440,7 @@ class PaginationHandler:
events, time_now, as_client_event=as_client_event events, time_now, as_client_event=as_client_event
) )
), ),
"start": pagin_config.from_token.to_string(), "start": from_token.to_string(),
"end": next_token.to_string(), "end": next_token.to_string(),
} }

View File

@ -1108,9 +1108,6 @@ class PresenceEventSource:
def get_current_key(self): def get_current_key(self):
return self.store.get_current_presence_token() return self.store.get_current_presence_token()
async def get_pagination_rows(self, user, pagination_config, key):
return await self.get_new_events(user, from_key=None, include_offline=False)
@cached(num_args=2, cache_context=True) @cached(num_args=2, cache_context=True)
async def _get_interested_in(self, user, explicit_room_id, cache_context): async def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence """Returns the set of users that the given user should see presence

View File

@ -142,18 +142,3 @@ class ReceiptEventSource:
def get_current_key(self, direction="f"): def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id() return self.store.get_max_receipt_stream_id()
async def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
if config.to_key:
from_key = int(config.to_key)
else:
from_key = None
room_ids = await self.store.get_rooms_for_user(user.to_string())
events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
return (events, to_key)

View File

@ -432,8 +432,9 @@ class Notifier:
If explicit_room_id is set, that room will be polled for events only if If explicit_room_id is set, that room will be polled for events only if
it is world readable or the user has joined the room. it is world readable or the user has joined the room.
""" """
if pagination_config.from_token:
from_token = pagination_config.from_token from_token = pagination_config.from_token
if not from_token: else:
from_token = self.event_sources.get_current_token() from_token = self.event_sources.get_current_token()
limit = pagination_config.limit limit = pagination_config.limit

View File

@ -14,9 +14,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
import attr
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import StreamToken from synapse.types import StreamToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,38 +29,22 @@ logger = logging.getLogger(__name__)
MAX_LIMIT = 1000 MAX_LIMIT = 1000
class SourcePaginationConfig: @attr.s(slots=True)
"""A configuration object which stores pagination parameters for a
specific event source."""
def __init__(self, from_key=None, to_key=None, direction="f", limit=None):
self.from_key = from_key
self.to_key = to_key
self.direction = "f" if direction == "f" else "b"
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
def __repr__(self):
return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % (
self.from_key,
self.to_key,
self.direction,
self.limit,
)
class PaginationConfig: class PaginationConfig:
"""A configuration object which stores pagination parameters.""" """A configuration object which stores pagination parameters."""
def __init__(self, from_token=None, to_token=None, direction="f", limit=None): from_token = attr.ib(type=Optional[StreamToken])
self.from_token = from_token to_token = attr.ib(type=Optional[StreamToken])
self.to_token = to_token direction = attr.ib(type=str)
self.direction = "f" if direction == "f" else "b" limit = attr.ib(type=Optional[int])
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
@classmethod @classmethod
def from_request(cls, request, raise_invalid_params=True, default_limit=None): def from_request(
cls,
request: SynapseRequest,
raise_invalid_params: bool = True,
default_limit: Optional[int] = None,
) -> "PaginationConfig":
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
from_tok = parse_string(request, "from") from_tok = parse_string(request, "from")
@ -78,29 +66,22 @@ class PaginationConfig:
limit = parse_integer(request, "limit", default=default_limit) limit = parse_integer(request, "limit", default=default_limit)
if limit and limit < 0: if limit:
if limit < 0:
raise SynapseError(400, "Limit must be 0 or above") raise SynapseError(400, "Limit must be 0 or above")
limit = min(int(limit), MAX_LIMIT)
try: try:
return PaginationConfig(from_tok, to_tok, direction, limit) return PaginationConfig(from_tok, to_tok, direction, limit)
except Exception: except Exception:
logger.exception("Failed to create pagination config") logger.exception("Failed to create pagination config")
raise SynapseError(400, "Invalid request.") raise SynapseError(400, "Invalid request.")
def __repr__(self): def __repr__(self) -> str:
return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % ( return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
self.from_token, self.from_token,
self.to_token, self.to_token,
self.direction, self.direction,
self.limit, self.limit,
) )
def get_source_config(self, source_name):
keyname = "%s_key" % source_name
return SourcePaginationConfig(
from_key=getattr(self.from_token, keyname),
to_key=getattr(self.to_token, keyname) if self.to_token else None,
direction=self.direction,
limit=self.limit,
)