Do not convert async functions to Deferreds in the interactive_auth_handler (#7944)
This commit is contained in:
parent
5ea29d7f85
commit
53f7b49f5b
|
@ -0,0 +1 @@
|
||||||
|
Convert the interactive_auth_handler wrapper to async/await.
|
|
@ -17,8 +17,7 @@
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from typing import Iterable, Pattern
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||||
from synapse.api.urls import CLIENT_API_PREFIX
|
from synapse.api.urls import CLIENT_API_PREFIX
|
||||||
|
@ -27,15 +26,23 @@ from synapse.types import JsonDict
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
|
def client_patterns(
|
||||||
|
path_regex: str,
|
||||||
|
releases: Iterable[int] = (0,),
|
||||||
|
unstable: bool = True,
|
||||||
|
v1: bool = False,
|
||||||
|
) -> Iterable[Pattern]:
|
||||||
"""Creates a regex compiled client path with the correct client path
|
"""Creates a regex compiled client path with the correct client path
|
||||||
prefix.
|
prefix.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path_regex (str): The regex string to match. This should NOT have a ^
|
path_regex: The regex string to match. This should NOT have a ^
|
||||||
as this will be prefixed.
|
as this will be prefixed.
|
||||||
|
releases: An iterable of releases to include this endpoint under.
|
||||||
|
unstable: If true, include this endpoint under the "unstable" prefix.
|
||||||
|
v1: If true, include this endpoint under the "api/v1" prefix.
|
||||||
Returns:
|
Returns:
|
||||||
SRE_Pattern
|
An iterable of patterns.
|
||||||
"""
|
"""
|
||||||
patterns = []
|
patterns = []
|
||||||
|
|
||||||
|
@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
|
||||||
def interactive_auth_handler(orig):
|
def interactive_auth_handler(orig):
|
||||||
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
|
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
|
||||||
|
|
||||||
Takes a on_POST method which returns a deferred (errcode, body) response
|
Takes a on_POST method which returns an Awaitable (errcode, body) response
|
||||||
and adds exception handling to turn a InteractiveAuthIncompleteError into
|
and adds exception handling to turn a InteractiveAuthIncompleteError into
|
||||||
a 401 response.
|
a 401 response.
|
||||||
|
|
||||||
Normal usage is:
|
Normal usage is:
|
||||||
|
|
||||||
@interactive_auth_handler
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
async def on_POST(self, request):
|
||||||
def on_POST(self, request):
|
|
||||||
# ...
|
# ...
|
||||||
yield self.auth_handler.check_auth
|
await self.auth_handler.check_auth
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapped(*args, **kwargs):
|
async def wrapped(*args, **kwargs):
|
||||||
res = defer.ensureDeferred(orig(*args, **kwargs))
|
try:
|
||||||
res.addErrback(_catch_incomplete_interactive_auth)
|
return await orig(*args, **kwargs)
|
||||||
return res
|
except InteractiveAuthIncompleteError as e:
|
||||||
|
return 401, e.result
|
||||||
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def _catch_incomplete_interactive_auth(f):
|
|
||||||
"""helper for interactive_auth_handler
|
|
||||||
|
|
||||||
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
|
|
||||||
|
|
||||||
Args:
|
|
||||||
f (failure.Failure):
|
|
||||||
"""
|
|
||||||
f.trap(InteractiveAuthIncompleteError)
|
|
||||||
return 401, f.value.result
|
|
||||||
|
|
Loading…
Reference in New Issue