Merge pull request #817 from matrix-org/dbkr/split_out_auth_handler
Split out the auth handler
This commit is contained in:
commit
fb2193cc63
|
@ -24,7 +24,6 @@ from .federation import FederationHandler
|
||||||
from .profile import ProfileHandler
|
from .profile import ProfileHandler
|
||||||
from .directory import DirectoryHandler
|
from .directory import DirectoryHandler
|
||||||
from .admin import AdminHandler
|
from .admin import AdminHandler
|
||||||
from .auth import AuthHandler
|
|
||||||
from .identity import IdentityHandler
|
from .identity import IdentityHandler
|
||||||
from .receipts import ReceiptsHandler
|
from .receipts import ReceiptsHandler
|
||||||
from .search import SearchHandler
|
from .search import SearchHandler
|
||||||
|
@ -50,7 +49,6 @@ class Handlers(object):
|
||||||
self.directory_handler = DirectoryHandler(hs)
|
self.directory_handler = DirectoryHandler(hs)
|
||||||
self.admin_handler = AdminHandler(hs)
|
self.admin_handler = AdminHandler(hs)
|
||||||
self.receipts_handler = ReceiptsHandler(hs)
|
self.receipts_handler = ReceiptsHandler(hs)
|
||||||
self.auth_handler = AuthHandler(hs)
|
|
||||||
self.identity_handler = IdentityHandler(hs)
|
self.identity_handler = IdentityHandler(hs)
|
||||||
self.search_handler = SearchHandler(hs)
|
self.search_handler = SearchHandler(hs)
|
||||||
self.room_context_handler = RoomContextHandler(hs)
|
self.room_context_handler = RoomContextHandler(hs)
|
||||||
|
|
|
@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
||||||
def auth_handler(self):
|
def auth_handler(self):
|
||||||
return self.hs.get_handlers().auth_handler
|
return self.hs.get_auth_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def guest_access_token_for(self, medium, address, inviter_user_id):
|
def guest_access_token_for(self, medium, address, inviter_user_id):
|
||||||
|
|
|
@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
self.servername = hs.config.server_name
|
self.servername = hs.config.server_name
|
||||||
self.http_client = hs.get_simple_http_client()
|
self.http_client = hs.get_simple_http_client()
|
||||||
|
self.auth_handler = self.hs.get_auth_handler()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
flows = []
|
flows = []
|
||||||
|
@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
user_id, self.hs.hostname
|
user_id, self.hs.hostname
|
||||||
).to_string()
|
).to_string()
|
||||||
|
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password=login_submission["password"])
|
password=login_submission["password"])
|
||||||
|
@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_token_login(self, login_submission):
|
def do_token_login(self, login_submission):
|
||||||
token = login_submission['token']
|
token = login_submission['token']
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = (
|
user_id = (
|
||||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||||
)
|
)
|
||||||
|
@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
if user_exists:
|
if user_exists:
|
||||||
user_id, access_token, refresh_token = (
|
user_id, access_token, refresh_token = (
|
||||||
|
@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
if user_exists:
|
if user_exists:
|
||||||
user_id, access_token, refresh_token = (
|
user_id, access_token, refresh_token = (
|
||||||
|
@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
if not user_exists:
|
if not user_exists:
|
||||||
user_id, _ = (
|
user_id, _ = (
|
||||||
|
|
|
@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
|
||||||
super(PasswordRestServlet, self).__init__()
|
super(PasswordRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
|
|
@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
|
||||||
super(AuthRestServlet, self).__init__()
|
super(AuthRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
try:
|
try:
|
||||||
old_refresh_token = body["refresh_token"]
|
old_refresh_token = body["refresh_token"]
|
||||||
auth_handler = self.hs.get_handlers().auth_handler
|
auth_handler = self.hs.get_auth_handler()
|
||||||
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
||||||
old_refresh_token, auth_handler.generate_refresh_token)
|
old_refresh_token, auth_handler.generate_refresh_token)
|
||||||
new_access_token = yield auth_handler.issue_access_token(user_id)
|
new_access_token = yield auth_handler.issue_access_token(user_id)
|
||||||
|
|
|
@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.handlers.sync import SyncHandler
|
from synapse.handlers.sync import SyncHandler
|
||||||
from synapse.handlers.typing import TypingHandler
|
from synapse.handlers.typing import TypingHandler
|
||||||
from synapse.handlers.room import RoomListHandler
|
from synapse.handlers.room import RoomListHandler
|
||||||
|
from synapse.handlers.auth import AuthHandler
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.state import StateHandler
|
from synapse.state import StateHandler
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
|
@ -89,6 +90,7 @@ class HomeServer(object):
|
||||||
'sync_handler',
|
'sync_handler',
|
||||||
'typing_handler',
|
'typing_handler',
|
||||||
'room_list_handler',
|
'room_list_handler',
|
||||||
|
'auth_handler',
|
||||||
'application_service_api',
|
'application_service_api',
|
||||||
'application_service_scheduler',
|
'application_service_scheduler',
|
||||||
'application_service_handler',
|
'application_service_handler',
|
||||||
|
@ -190,6 +192,9 @@ class HomeServer(object):
|
||||||
def build_room_list_handler(self):
|
def build_room_list_handler(self):
|
||||||
return RoomListHandler(self)
|
return RoomListHandler(self)
|
||||||
|
|
||||||
|
def build_auth_handler(self):
|
||||||
|
return AuthHandler(self)
|
||||||
|
|
||||||
def build_application_service_api(self):
|
def build_application_service_api(self):
|
||||||
return ApplicationServiceApi(self)
|
return ApplicationServiceApi(self)
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# do the dance to hook it up to the hs global
|
# do the dance to hook it up to the hs global
|
||||||
self.handlers = Mock(
|
self.handlers = Mock(
|
||||||
auth_handler=self.auth_handler,
|
|
||||||
registration_handler=self.registration_handler,
|
registration_handler=self.registration_handler,
|
||||||
identity_handler=self.identity_handler,
|
identity_handler=self.identity_handler,
|
||||||
login_handler=self.login_handler
|
login_handler=self.login_handler
|
||||||
|
@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.hs.hostname = "superbig~testing~thing.com"
|
self.hs.hostname = "superbig~testing~thing.com"
|
||||||
self.hs.get_auth = Mock(return_value=self.auth)
|
self.hs.get_auth = Mock(return_value=self.auth)
|
||||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||||
|
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||||
self.hs.config.enable_registration = True
|
self.hs.config.enable_registration = True
|
||||||
|
|
||||||
# init the thing we're testing
|
# init the thing we're testing
|
||||||
|
|
|
@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
)
|
)
|
||||||
|
|
||||||
# bcrypt is far too slow to be doing in unit tests
|
# bcrypt is far too slow to be doing in unit tests
|
||||||
def swap_out_hash_for_testing(old_build_handlers):
|
# Need to let the HS build an auth handler and then mess with it
|
||||||
def build_handlers():
|
# because AuthHandler's constructor requires the HS, so we can't make one
|
||||||
handlers = old_build_handlers()
|
# beforehand and pass it in to the HS's constructor (chicken / egg)
|
||||||
auth_handler = handlers.auth_handler
|
hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
|
||||||
auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
|
hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
|
||||||
auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
|
|
||||||
return handlers
|
|
||||||
return build_handlers
|
|
||||||
|
|
||||||
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
|
|
||||||
|
|
||||||
fed = kargs.get("resource_for_federation", None)
|
fed = kargs.get("resource_for_federation", None)
|
||||||
if fed:
|
if fed:
|
||||||
|
|
Loading…
Reference in New Issue