Pass the Requester down to the HttpTransactionCache. (#15200)
This commit is contained in:
parent
820f02b70b
commit
47bc84dd53
|
@ -0,0 +1 @@
|
|||
Make the `HttpTransactionCache` use the `Requester` in addition of the just the `Request` to build the transaction key.
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
|
@ -23,10 +23,10 @@ from synapse.http.servlet import (
|
|||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin import assert_requester_is_admin
|
||||
from synapse.rest.admin._base import admin_patterns
|
||||
from synapse.logging.opentracing import set_tag
|
||||
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -70,10 +70,13 @@ class SendServerNoticeServlet(RestServlet):
|
|||
self.__class__.__name__,
|
||||
)
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, txn_id: Optional[str] = None
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
await assert_user_is_admin(self.auth, requester)
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ("user_id", "content"))
|
||||
event_type = body.get("type", EventTypes.Message)
|
||||
|
@ -106,9 +109,18 @@ class SendServerNoticeServlet(RestServlet):
|
|||
|
||||
return HTTPStatus.OK, {"event_id": event.event_id}
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
return await self._do(request, requester, None)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, txn_id
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
set_tag("txn_id", txn_id)
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, request, requester, txn_id
|
||||
)
|
||||
|
|
|
@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.cancellation import cancellable
|
||||
|
@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/createRoom"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
def on_PUT(
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
set_tag("txn_id", txn_id)
|
||||
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, request, requester
|
||||
)
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
return await self._do(request, requester)
|
||||
|
||||
async def _do(
|
||||
self, request: SynapseRequest, requester: Requester
|
||||
) -> Tuple[int, JsonDict]:
|
||||
room_id, _, _ = await self._room_creation_handler.create_room(
|
||||
requester, self.get_room_config(request)
|
||||
)
|
||||
|
@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||
|
||||
|
||||
# TODO: Needs unit testing for generic events
|
||||
class RoomStateEventRestServlet(TransactionRestServlet):
|
||||
class RoomStateEventRestServlet(RestServlet):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
super().__init__()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.message_handler = hs.get_message_handler()
|
||||
|
@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
def register(self, http_server: HttpServer) -> None:
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server, with_get=True)
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
txn_id: Optional[str] = None,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event_dict: JsonDict = {
|
||||
|
@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
set_tag("event_id", event_id)
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
def on_GET(
|
||||
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
|
||||
) -> Tuple[int, str]:
|
||||
return 200, "Not implemented"
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
return await self._do(request, requester, room_id, event_type, None)
|
||||
|
||||
def on_PUT(
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, event_type, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request,
|
||||
requester,
|
||||
self._do,
|
||||
request,
|
||||
requester,
|
||||
room_id,
|
||||
event_type,
|
||||
txn_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
|
|||
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_identifier: str,
|
||||
txn_id: Optional[str] = None,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
content = parse_json_object_from_request(request, allow_empty_body=True)
|
||||
|
||||
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
|
||||
|
@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
|
|||
|
||||
return 200, {"room_id": room_id}
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_identifier: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
return await self._do(request, requester, room_identifier, None)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_identifier: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_identifier, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, request, requester, room_identifier, txn_id
|
||||
)
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class PublicRoomListRestServlet(TransactionRestServlet):
|
||||
class PublicRoomListRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/publicRooms$", v1=True)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
|
@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
|
||||
async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
|
||||
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
|
||||
|
||||
return 200, {}
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
return await self._do(requester, room_id)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, requester, room_id
|
||||
)
|
||||
|
||||
|
||||
|
@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
)
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
membership_action: str,
|
||||
txn_id: Optional[str] = None,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
if requester.is_guest and membership_action not in {
|
||||
Membership.JOIN,
|
||||
Membership.LEAVE,
|
||||
|
@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
|
||||
return 200, return_value
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
membership_action: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
return await self._do(request, requester, room_id, membership_action, None)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, membership_action, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request,
|
||||
requester,
|
||||
self._do,
|
||||
request,
|
||||
requester,
|
||||
room_id,
|
||||
membership_action,
|
||||
txn_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
async def on_POST(
|
||||
async def _do(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
txn_id: Optional[str] = None,
|
||||
txn_id: Optional[str],
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
|
@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||
set_tag("event_id", event_id)
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
def on_PUT(
|
||||
async def on_POST(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
return await self._do(request, requester, room_id, event_id, None)
|
||||
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, room_id, event_id, txn_id
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request, requester, self._do, request, requester, room_id, event_id, txn_id
|
||||
)
|
||||
|
||||
|
||||
|
@ -1224,7 +1280,6 @@ def register_txn_path(
|
|||
servlet: RestServlet,
|
||||
regex_string: str,
|
||||
http_server: HttpServer,
|
||||
with_get: bool = False,
|
||||
) -> None:
|
||||
"""Registers a transaction-based path.
|
||||
|
||||
|
@ -1236,7 +1291,6 @@ def register_txn_path(
|
|||
regex_string: The regex string to register. Must NOT have a
|
||||
trailing $ as this string will be appended to.
|
||||
http_server: The http_server to register paths with.
|
||||
with_get: True to also register respective GET paths for the PUTs.
|
||||
"""
|
||||
on_POST = getattr(servlet, "on_POST", None)
|
||||
on_PUT = getattr(servlet, "on_PUT", None)
|
||||
|
@ -1254,18 +1308,6 @@ def register_txn_path(
|
|||
on_PUT,
|
||||
servlet.__class__.__name__,
|
||||
)
|
||||
on_GET = getattr(servlet, "on_GET", None)
|
||||
if with_get:
|
||||
if on_GET is None:
|
||||
raise RuntimeError(
|
||||
"register_txn_path called with with_get = True, but no on_GET method exists"
|
||||
)
|
||||
http_server.register_paths(
|
||||
"GET",
|
||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
|
||||
on_GET,
|
||||
servlet.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
class TimestampLookupRestServlet(RestServlet):
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Tuple
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.http import servlet
|
||||
from synapse.http.server import HttpServer
|
||||
|
@ -21,7 +21,7 @@ from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_r
|
|||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import set_tag
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, Requester
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
|
@ -43,19 +43,26 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
|||
self.txns = HttpTransactionCache(hs)
|
||||
self.device_message_handler = hs.get_device_message_handler()
|
||||
|
||||
def on_PUT(
|
||||
self, request: SynapseRequest, message_type: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
set_tag("txn_id", txn_id)
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self._put, request, message_type, txn_id
|
||||
)
|
||||
|
||||
async def _put(
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, message_type: str, txn_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
set_tag("txn_id", txn_id)
|
||||
return await self.txns.fetch_or_execute_request(
|
||||
request,
|
||||
requester,
|
||||
self._put,
|
||||
request,
|
||||
requester,
|
||||
message_type,
|
||||
)
|
||||
|
||||
async def _put(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
requester: Requester,
|
||||
message_type: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
content = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(content, ("messages",))
|
||||
|
||||
|
|
|
@ -15,16 +15,16 @@
|
|||
"""This module contains logic for storing HTTP PUT transactions. This is used
|
||||
to ensure idempotency when performing PUTs using the REST API."""
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.server import Request
|
||||
from twisted.web.iweb import IRequest
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, Requester
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -41,53 +41,47 @@ P = ParamSpec("P")
|
|||
class HttpTransactionCache:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = self.hs.get_auth()
|
||||
self.clock = self.hs.get_clock()
|
||||
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
|
||||
self.transactions: Dict[
|
||||
str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
|
||||
Hashable, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
|
||||
] = {}
|
||||
# Try to clean entries every 30 mins. This means entries will exist
|
||||
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
|
||||
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
|
||||
|
||||
def _get_transaction_key(self, request: Request) -> str:
|
||||
def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable:
|
||||
"""A helper function which returns a transaction key that can be used
|
||||
with TransactionCache for idempotent requests.
|
||||
|
||||
Idempotency is based on the returned key being the same for separate
|
||||
requests to the same endpoint. The key is formed from the HTTP request
|
||||
path and the access_token for the requesting user.
|
||||
path and attributes from the requester: the access_token_id for regular users,
|
||||
the user ID for guest users, and the appservice ID for appservice users.
|
||||
|
||||
Args:
|
||||
request: The incoming request. Must contain an access_token.
|
||||
request: The incoming request.
|
||||
requester: The requester doing the request.
|
||||
Returns:
|
||||
A transaction key
|
||||
"""
|
||||
assert request.path is not None
|
||||
token = self.auth.get_access_token_from_request(request)
|
||||
return request.path.decode("utf8") + "/" + token
|
||||
path: str = request.path.decode("utf8")
|
||||
if requester.is_guest:
|
||||
assert requester.user is not None, "Guest requester must have a user ID set"
|
||||
return (path, "guest", requester.user)
|
||||
elif requester.app_service is not None:
|
||||
return (path, "appservice", requester.app_service.id)
|
||||
else:
|
||||
assert (
|
||||
requester.access_token_id is not None
|
||||
), "Requester must have an access_token_id"
|
||||
return (path, "user", requester.access_token_id)
|
||||
|
||||
def fetch_or_execute_request(
|
||||
self,
|
||||
request: Request,
|
||||
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
"""A helper function for fetch_or_execute which extracts
|
||||
a transaction key from the given request.
|
||||
|
||||
See:
|
||||
fetch_or_execute
|
||||
"""
|
||||
return self.fetch_or_execute(
|
||||
self._get_transaction_key(request), fn, *args, **kwargs
|
||||
)
|
||||
|
||||
def fetch_or_execute(
|
||||
self,
|
||||
txn_key: str,
|
||||
request: IRequest,
|
||||
requester: Requester,
|
||||
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
|
@ -96,14 +90,15 @@ class HttpTransactionCache:
|
|||
to produce a response for this transaction.
|
||||
|
||||
Args:
|
||||
txn_key: A key to ensure idempotency should fetch_or_execute be
|
||||
called again at a later point in time.
|
||||
request:
|
||||
requester:
|
||||
fn: A function which returns a tuple of (response_code, response_dict).
|
||||
*args: Arguments to pass to fn.
|
||||
**kwargs: Keyword arguments to pass to fn.
|
||||
Returns:
|
||||
Deferred which resolves to a tuple of (response_code, response_dict).
|
||||
"""
|
||||
txn_key = self._get_transaction_key(request, requester)
|
||||
if txn_key in self.transactions:
|
||||
observable = self.transactions[txn_key][0]
|
||||
else:
|
||||
|
|
|
@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||
self.cache = HttpTransactionCache(self.hs)
|
||||
|
||||
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
|
||||
self.mock_key = "foo"
|
||||
|
||||
# Here we make sure that we're setting all the fields that HttpTransactionCache
|
||||
# uses to build the transaction key.
|
||||
self.mock_request = Mock()
|
||||
self.mock_request.path = b"/foo/bar"
|
||||
self.mock_requester = Mock()
|
||||
self.mock_requester.app_service = None
|
||||
self.mock_requester.is_guest = False
|
||||
self.mock_requester.access_token_id = 1234
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_executes_given_function(
|
||||
self,
|
||||
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||
res = yield self.cache.fetch_or_execute(
|
||||
self.mock_key, cb, "some_arg", keyword="arg"
|
||||
res = yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
|
||||
)
|
||||
cb.assert_called_once_with("some_arg", keyword="arg")
|
||||
self.assertEqual(res, self.mock_http_response)
|
||||
|
@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||
) -> Generator["defer.Deferred[Any]", object, None]:
|
||||
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||
for i in range(3): # invoke multiple times
|
||||
res = yield self.cache.fetch_or_execute(
|
||||
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
|
||||
res = yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request,
|
||||
self.mock_requester,
|
||||
cb,
|
||||
"some_arg",
|
||||
keyword="arg",
|
||||
changing_args=i,
|
||||
)
|
||||
self.assertEqual(res, self.mock_http_response)
|
||||
# expect only a single call to do the work
|
||||
|
@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test() -> Generator["defer.Deferred[Any]", object, None]:
|
||||
with LoggingContext("c") as c1:
|
||||
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
|
||||
res = yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb
|
||||
)
|
||||
self.assertIs(current_context(), c1)
|
||||
self.assertEqual(res, (1, {}))
|
||||
|
||||
|
@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||
|
||||
with LoggingContext("test") as test_context:
|
||||
try:
|
||||
yield self.cache.fetch_or_execute(self.mock_key, cb)
|
||||
yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb
|
||||
)
|
||||
except Exception as e:
|
||||
self.assertEqual(e.args[0], "boo")
|
||||
self.assertIs(current_context(), test_context)
|
||||
|
||||
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
|
||||
res = yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb
|
||||
)
|
||||
self.assertEqual(res, self.mock_http_response)
|
||||
self.assertIs(current_context(), test_context)
|
||||
|
||||
|
@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||
|
||||
with LoggingContext("test") as test_context:
|
||||
try:
|
||||
yield self.cache.fetch_or_execute(self.mock_key, cb)
|
||||
yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb
|
||||
)
|
||||
except Exception as e:
|
||||
self.assertEqual(e.args[0], "boo")
|
||||
self.assertIs(current_context(), test_context)
|
||||
|
||||
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
|
||||
res = yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb
|
||||
)
|
||||
self.assertEqual(res, self.mock_http_response)
|
||||
self.assertIs(current_context(), test_context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
|
||||
cb = Mock(return_value=make_awaitable(self.mock_http_response))
|
||||
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
|
||||
yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb, "an arg"
|
||||
)
|
||||
# should NOT have cleaned up yet
|
||||
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
|
||||
|
||||
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
|
||||
yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb, "an arg"
|
||||
)
|
||||
# still using cache
|
||||
cb.assert_called_once_with("an arg")
|
||||
|
||||
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
|
||||
|
||||
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
|
||||
yield self.cache.fetch_or_execute_request(
|
||||
self.mock_request, self.mock_requester, cb, "an arg"
|
||||
)
|
||||
# no longer using cache
|
||||
self.assertEqual(cb.call_count, 2)
|
||||
self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])
|
||||
|
|
Loading…
Reference in New Issue