Split registration store

This commit is contained in:
Erik Johnston 2018-03-01 18:19:34 +00:00
parent 1a6c7cdf54
commit fafa3e7114
2 changed files with 64 additions and 72 deletions

View File

@ -14,20 +14,8 @@
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage.registration import RegistrationWorkerStore
from synapse.storage.registration import RegistrationStore
class SlavedRegistrationStore(BaseSlavedStore): class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs): pass
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
]
_query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]

View File

@ -19,10 +19,70 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(background_updates.BackgroundUpdateStore): class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)
@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
)
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)
defer.returnValue(res if res else False)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]
return None
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs) super(RegistrationStore, self).__init__(db_conn, hs)
@ -187,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
) )
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
desc="get_user_by_id",
)
def get_users_by_id_case_insensitive(self, user_id): def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively. """Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash. Returns a mapping of user_id -> password_hash.
@ -304,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("delete_access_token", f) return self.runInteraction("delete_access_token", f)
@cached()
def get_user_by_access_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
)
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)
defer.returnValue(res if res else False)
@cachedInlineCallbacks() @cachedInlineCallbacks()
def is_guest(self, user_id): def is_guest(self, user_id):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(
@ -344,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
txn.execute(sql, (token,))
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]
return None
@defer.inlineCallbacks @defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at): def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", { yield self._simple_upsert("user_threepids", {