Use an enum for direction. (#14927)
For better type safety we use an enum instead of strings to configure direction (backwards or forwards).
This commit is contained in:
parent
fc35e0673f
commit
265735db9d
|
@ -0,0 +1 @@
|
|||
Add missing type hints.
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
"""Contains constants from the specification."""
|
||||
|
||||
import enum
|
||||
|
||||
from typing_extensions import Final
|
||||
|
||||
# the max size of a (canonical-json-encoded) event
|
||||
|
@ -290,3 +292,8 @@ class ApprovalNoticeMedium:
|
|||
|
||||
NONE = "org.matrix.msc3866.none"
|
||||
EMAIL = "org.matrix.msc3866.email"
|
||||
|
||||
|
||||
class Direction(enum.Enum):
|
||||
BACKWARDS = "b"
|
||||
FORWARDS = "f"
|
||||
|
|
|
@ -16,7 +16,7 @@ import abc
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import Direction, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
@ -197,7 +197,7 @@ class AdminHandler:
|
|||
# efficient method perhaps but it does guarantee we get everything.
|
||||
while True:
|
||||
events, _ = await self.store.paginate_room_events(
|
||||
room_id, from_key, to_key, limit=100, direction="f"
|
||||
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
|
|
|
@ -15,7 +15,13 @@
|
|||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||
|
||||
from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
|
||||
from synapse.api.constants import (
|
||||
AccountDataTypes,
|
||||
Direction,
|
||||
EduTypes,
|
||||
EventTypes,
|
||||
Membership,
|
||||
)
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import SerializeEventConfig
|
||||
|
@ -57,7 +63,13 @@ class InitialSyncHandler:
|
|||
self.validator = EventValidator()
|
||||
self.snapshot_cache: ResponseCache[
|
||||
Tuple[
|
||||
str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
|
||||
str,
|
||||
Optional[StreamToken],
|
||||
Optional[StreamToken],
|
||||
Direction,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
]
|
||||
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
|
|
@ -19,7 +19,7 @@ import attr
|
|||
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.constants import Direction, EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events.utils import SerializeEventConfig
|
||||
|
@ -448,7 +448,7 @@ class PaginationHandler:
|
|||
|
||||
if pagin_config.from_token:
|
||||
from_token = pagin_config.from_token
|
||||
elif pagin_config.direction == "f":
|
||||
elif pagin_config.direction == Direction.FORWARDS:
|
||||
from_token = (
|
||||
await self.hs.get_event_sources().get_start_token_for_pagination(
|
||||
room_id
|
||||
|
@ -476,7 +476,7 @@ class PaginationHandler:
|
|||
room_id, requester, allow_departed_users=True
|
||||
)
|
||||
|
||||
if pagin_config.direction == "b":
|
||||
if pagin_config.direction == Direction.BACKWARDS:
|
||||
# if we're going backwards, we might need to backfill. This
|
||||
# requires that we have a topo token.
|
||||
if room_token.topological:
|
||||
|
|
|
@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O
|
|||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.constants import Direction, EventTypes, RelationTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase, relation_from_event
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
|
@ -413,7 +413,11 @@ class RelationsHandler:
|
|||
|
||||
# Attempt to find another event to use as the latest event.
|
||||
potential_events, _ = await self._main_store.get_relations_for_event(
|
||||
event_id, event, room_id, RelationTypes.THREAD, direction="f"
|
||||
event_id,
|
||||
event,
|
||||
room_id,
|
||||
RelationTypes.THREAD,
|
||||
direction=Direction.FORWARDS,
|
||||
)
|
||||
|
||||
# Filter out ignored users.
|
||||
|
|
|
@ -30,7 +30,7 @@ from typing import (
|
|||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import MAIN_TIMELINE, RelationTypes
|
||||
from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
@ -168,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
relation_type: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
direction: str = "b",
|
||||
direction: Direction = Direction.BACKWARDS,
|
||||
from_token: Optional[StreamToken] = None,
|
||||
to_token: Optional[StreamToken] = None,
|
||||
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
|
||||
|
@ -181,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
relation_type: Only fetch events with this relation type, if given.
|
||||
event_type: Only fetch events with this event type, if given.
|
||||
limit: Only fetch the most recent `limit` events.
|
||||
direction: Whether to fetch the most recent first (`"b"`) or the
|
||||
oldest first (`"f"`).
|
||||
direction: Whether to fetch the most recent first (backwards) or the
|
||||
oldest first (forwards).
|
||||
from_token: Fetch rows from the given token, or from the start if None.
|
||||
to_token: Fetch rows up to the given token, or up to the end if None.
|
||||
|
||||
|
|
|
@ -55,6 +55,7 @@ from typing_extensions import Literal
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
|
@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
|
|||
_STREAM_TOKEN = "stream"
|
||||
_TOPOLOGICAL_TOKEN = "topological"
|
||||
|
||||
|
||||
# Used as return values for pagination APIs
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class _EventDictReturn:
|
||||
|
@ -104,7 +104,7 @@ class _EventsAround:
|
|||
|
||||
|
||||
def generate_pagination_where_clause(
|
||||
direction: str,
|
||||
direction: Direction,
|
||||
column_names: Tuple[str, str],
|
||||
from_token: Optional[Tuple[Optional[int], int]],
|
||||
to_token: Optional[Tuple[Optional[int], int]],
|
||||
|
@ -130,27 +130,26 @@ def generate_pagination_where_clause(
|
|||
token, but include those that match the to token.
|
||||
|
||||
Args:
|
||||
direction: Whether we're paginating backwards("b") or forwards ("f").
|
||||
direction: Whether we're paginating backwards or forwards.
|
||||
column_names: The column names to bound. Must *not* be user defined as
|
||||
these get inserted directly into the SQL statement without escapes.
|
||||
from_token: The start point for the pagination. This is an exclusive
|
||||
minimum bound if direction is "f", and an inclusive maximum bound if
|
||||
direction is "b".
|
||||
minimum bound if direction is forwards, and an inclusive maximum bound if
|
||||
direction is backwards.
|
||||
to_token: The endpoint point for the pagination. This is an inclusive
|
||||
maximum bound if direction is "f", and an exclusive minimum bound if
|
||||
direction is "b".
|
||||
maximum bound if direction is forwards, and an exclusive minimum bound if
|
||||
direction is backwards.
|
||||
engine: The database engine to generate the clauses for
|
||||
|
||||
Returns:
|
||||
The sql expression
|
||||
"""
|
||||
assert direction in ("b", "f")
|
||||
|
||||
where_clause = []
|
||||
if from_token:
|
||||
where_clause.append(
|
||||
_make_generic_sql_bound(
|
||||
bound=">=" if direction == "b" else "<",
|
||||
bound=">=" if direction == Direction.BACKWARDS else "<",
|
||||
column_names=column_names,
|
||||
values=from_token,
|
||||
engine=engine,
|
||||
|
@ -160,7 +159,7 @@ def generate_pagination_where_clause(
|
|||
if to_token:
|
||||
where_clause.append(
|
||||
_make_generic_sql_bound(
|
||||
bound="<" if direction == "b" else ">=",
|
||||
bound="<" if direction == Direction.BACKWARDS else ">=",
|
||||
column_names=column_names,
|
||||
values=to_token,
|
||||
engine=engine,
|
||||
|
@ -171,7 +170,7 @@ def generate_pagination_where_clause(
|
|||
|
||||
|
||||
def generate_pagination_bounds(
|
||||
direction: str,
|
||||
direction: Direction,
|
||||
from_token: Optional[RoomStreamToken],
|
||||
to_token: Optional[RoomStreamToken],
|
||||
) -> Tuple[
|
||||
|
@ -181,7 +180,7 @@ def generate_pagination_bounds(
|
|||
Generate a start and end point for this page of events.
|
||||
|
||||
Args:
|
||||
direction: Whether pagination is going forwards or backwards. One of "f" or "b".
|
||||
direction: Whether pagination is going forwards or backwards.
|
||||
from_token: The token to start pagination at, or None to start at the first value.
|
||||
to_token: The token to end pagination at, or None to not limit the end point.
|
||||
|
||||
|
@ -201,7 +200,7 @@ def generate_pagination_bounds(
|
|||
# Tokens really represent positions between elements, but we use
|
||||
# the convention of pointing to the event before the gap. Hence
|
||||
# we have a bit of asymmetry when it comes to equalities.
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
@ -215,7 +214,7 @@ def generate_pagination_bounds(
|
|||
if from_token:
|
||||
if from_token.topological is not None:
|
||||
from_bound = from_token.as_historical_tuple()
|
||||
elif direction == "b":
|
||||
elif direction == Direction.BACKWARDS:
|
||||
from_bound = (
|
||||
None,
|
||||
from_token.get_max_stream_pos(),
|
||||
|
@ -230,7 +229,7 @@ def generate_pagination_bounds(
|
|||
if to_token:
|
||||
if to_token.topological is not None:
|
||||
to_bound = to_token.as_historical_tuple()
|
||||
elif direction == "b":
|
||||
elif direction == Direction.BACKWARDS:
|
||||
to_bound = (
|
||||
None,
|
||||
to_token.stream,
|
||||
|
@ -245,20 +244,20 @@ def generate_pagination_bounds(
|
|||
|
||||
|
||||
def generate_next_token(
|
||||
direction: str, last_topo_ordering: int, last_stream_ordering: int
|
||||
direction: Direction, last_topo_ordering: int, last_stream_ordering: int
|
||||
) -> RoomStreamToken:
|
||||
"""
|
||||
Generate the next room stream token based on the currently returned data.
|
||||
|
||||
Args:
|
||||
direction: Whether pagination is going forwards or backwards. One of "f" or "b".
|
||||
direction: Whether pagination is going forwards or backwards.
|
||||
last_topo_ordering: The last topological ordering being returned.
|
||||
last_stream_ordering: The last stream ordering being returned.
|
||||
|
||||
Returns:
|
||||
A new RoomStreamToken to return to the client.
|
||||
"""
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
# Tokens are positions between events.
|
||||
# This token points *after* the last event in the chunk.
|
||||
# We need it to point to the event before it in the chunk
|
||||
|
@ -1201,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
txn,
|
||||
room_id,
|
||||
before_token,
|
||||
direction="b",
|
||||
direction=Direction.BACKWARDS,
|
||||
limit=before_limit,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
|
@ -1211,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
txn,
|
||||
room_id,
|
||||
after_token,
|
||||
direction="f",
|
||||
direction=Direction.FORWARDS,
|
||||
limit=after_limit,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
|
@ -1374,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
room_id: str,
|
||||
from_token: RoomStreamToken,
|
||||
to_token: Optional[RoomStreamToken] = None,
|
||||
direction: str = "b",
|
||||
direction: Direction = Direction.BACKWARDS,
|
||||
limit: int = -1,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
|
||||
|
@ -1385,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
room_id
|
||||
from_token: The token used to stream from
|
||||
to_token: A token which if given limits the results to only those before
|
||||
direction: Either 'b' or 'f' to indicate whether we are paginating
|
||||
forwards or backwards from `from_key`.
|
||||
direction: Indicates whether we are paginating forwards or backwards
|
||||
from `from_key`.
|
||||
limit: The maximum number of events to return.
|
||||
event_filter: If provided filters the events to
|
||||
those that match the filter.
|
||||
|
@ -1489,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
_EventDictReturn(event_id, topological_ordering, stream_ordering)
|
||||
for event_id, instance_name, topological_ordering, stream_ordering in txn
|
||||
if _filter_results(
|
||||
lower_token=to_token if direction == "b" else from_token,
|
||||
upper_token=from_token if direction == "b" else to_token,
|
||||
lower_token=to_token
|
||||
if direction == Direction.BACKWARDS
|
||||
else from_token,
|
||||
upper_token=from_token
|
||||
if direction == Direction.BACKWARDS
|
||||
else to_token,
|
||||
instance_name=instance_name,
|
||||
topological_ordering=topological_ordering,
|
||||
stream_ordering=stream_ordering,
|
||||
|
@ -1514,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
room_id: str,
|
||||
from_key: RoomStreamToken,
|
||||
to_key: Optional[RoomStreamToken] = None,
|
||||
direction: str = "b",
|
||||
direction: Direction = Direction.BACKWARDS,
|
||||
limit: int = -1,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> Tuple[List[EventBase], RoomStreamToken]:
|
||||
|
@ -1524,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
room_id
|
||||
from_key: The token used to stream from
|
||||
to_key: A token which if given limits the results to only those before
|
||||
direction: Either 'b' or 'f' to indicate whether we are paginating
|
||||
forwards or backwards from `from_key`.
|
||||
direction: Indicates whether we are paginating forwards or backwards
|
||||
from `from_key`.
|
||||
limit: The maximum number of events to return.
|
||||
event_filter: If provided filters the events to those that match the filter.
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from typing import Optional
|
|||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
@ -34,7 +35,7 @@ class PaginationConfig:
|
|||
|
||||
from_token: Optional[StreamToken]
|
||||
to_token: Optional[StreamToken]
|
||||
direction: str
|
||||
direction: Direction
|
||||
limit: int
|
||||
|
||||
@classmethod
|
||||
|
@ -45,9 +46,13 @@ class PaginationConfig:
|
|||
default_limit: int,
|
||||
default_dir: str = "f",
|
||||
) -> "PaginationConfig":
|
||||
direction = parse_string(
|
||||
request, "dir", default=default_dir, allowed_values=["f", "b"]
|
||||
direction_str = parse_string(
|
||||
request,
|
||||
"dir",
|
||||
default=default_dir,
|
||||
allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
|
||||
)
|
||||
direction = Direction(direction_str)
|
||||
|
||||
from_tok_str = parse_string(request, "from")
|
||||
to_tok_str = parse_string(request, "to")
|
||||
|
|
Loading…
Reference in New Issue