Do not convert async functions to Deferreds in the interactive_auth_handler (#7944)
This commit is contained in:
parent
5ea29d7f85
commit
53f7b49f5b
|
@ -0,0 +1 @@
|
|||
Convert the interactive_auth_handler wrapper to async/await.
|
|
@ -17,8 +17,7 @@
|
|||
"""
|
||||
import logging
|
||||
import re
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Iterable, Pattern
|
||||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.api.urls import CLIENT_API_PREFIX
|
||||
|
@ -27,15 +26,23 @@ from synapse.types import JsonDict
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
|
||||
def client_patterns(
|
||||
path_regex: str,
|
||||
releases: Iterable[int] = (0,),
|
||||
unstable: bool = True,
|
||||
v1: bool = False,
|
||||
) -> Iterable[Pattern]:
|
||||
"""Creates a regex compiled client path with the correct client path
|
||||
prefix.
|
||||
|
||||
Args:
|
||||
path_regex (str): The regex string to match. This should NOT have a ^
|
||||
path_regex: The regex string to match. This should NOT have a ^
|
||||
as this will be prefixed.
|
||||
releases: An iterable of releases to include this endpoint under.
|
||||
unstable: If true, include this endpoint under the "unstable" prefix.
|
||||
v1: If true, include this endpoint under the "api/v1" prefix.
|
||||
Returns:
|
||||
SRE_Pattern
|
||||
An iterable of patterns.
|
||||
"""
|
||||
patterns = []
|
||||
|
||||
|
@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
|
|||
def interactive_auth_handler(orig):
|
||||
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
|
||||
|
||||
Takes a on_POST method which returns a deferred (errcode, body) response
|
||||
Takes a on_POST method which returns an Awaitable (errcode, body) response
|
||||
and adds exception handling to turn a InteractiveAuthIncompleteError into
|
||||
a 401 response.
|
||||
|
||||
Normal usage is:
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
async def on_POST(self, request):
|
||||
# ...
|
||||
yield self.auth_handler.check_auth
|
||||
"""
|
||||
await self.auth_handler.check_auth
|
||||
"""
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
res = defer.ensureDeferred(orig(*args, **kwargs))
|
||||
res.addErrback(_catch_incomplete_interactive_auth)
|
||||
return res
|
||||
async def wrapped(*args, **kwargs):
|
||||
try:
|
||||
return await orig(*args, **kwargs)
|
||||
except InteractiveAuthIncompleteError as e:
|
||||
return 401, e.result
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _catch_incomplete_interactive_auth(f):
|
||||
"""helper for interactive_auth_handler
|
||||
|
||||
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
|
||||
|
||||
Args:
|
||||
f (failure.Failure):
|
||||
"""
|
||||
f.trap(InteractiveAuthIncompleteError)
|
||||
return 401, f.value.result
|
||||
|
|
Loading…
Reference in New Issue