Merge branch 'develop' into dbkr/email_notifs_on_pusher
This commit is contained in:
commit
1f71f386f6
|
@ -612,7 +612,8 @@ class Auth(object):
|
|||
def get_user_from_macaroon(self, macaroon_str):
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||
self.validate_macaroon(macaroon, "access", False)
|
||||
|
||||
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
|
||||
|
||||
user_prefix = "user_id = "
|
||||
user = None
|
||||
|
|
|
@ -57,6 +57,8 @@ class KeyConfig(Config):
|
|||
seed = self.signing_key[0].seed
|
||||
self.macaroon_secret_key = hashlib.sha256(seed)
|
||||
|
||||
self.expire_access_token = config.get("expire_access_token", False)
|
||||
|
||||
def default_config(self, config_dir_path, server_name, is_generating_file=False,
|
||||
**kwargs):
|
||||
base_key_name = os.path.join(config_dir_path, server_name)
|
||||
|
@ -69,6 +71,9 @@ class KeyConfig(Config):
|
|||
return """\
|
||||
macaroon_secret_key: "%(macaroon_secret_key)s"
|
||||
|
||||
# Used to enable access token expiration.
|
||||
expire_access_token: False
|
||||
|
||||
## Signing Keys ##
|
||||
|
||||
# Path to the signing key to sign messages with
|
||||
|
|
|
@ -32,6 +32,7 @@ class RegistrationConfig(Config):
|
|||
)
|
||||
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
self.user_creation_max_duration = int(config["user_creation_max_duration"])
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||
|
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
|
|||
# secret, even if registration is otherwise disabled.
|
||||
registration_shared_secret: "%(registration_shared_secret)s"
|
||||
|
||||
# Sets the expiry for the short term user creation in
|
||||
# milliseconds. For instance the bellow duration is two weeks
|
||||
# in milliseconds.
|
||||
user_creation_max_duration: 1209600000
|
||||
|
||||
# Set the number of bcrypt rounds used to generate password hash.
|
||||
# Larger numbers increase the work factor needed to generate the hash.
|
||||
# The default number of rounds is 12.
|
||||
|
|
|
@ -521,11 +521,11 @@ class AuthHandler(BaseHandler):
|
|||
))
|
||||
return m.serialize()
|
||||
|
||||
def generate_short_term_login_token(self, user_id):
|
||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = login")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + (2 * 60 * 1000)
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
return macaroon.serialize()
|
||||
|
||||
|
|
|
@ -358,6 +358,59 @@ class RegistrationHandler(BaseHandler):
|
|||
)
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
||||
"""Creates a new user or returns an access token for an existing one
|
||||
|
||||
Args:
|
||||
localpart : The local part of the user ID to register. If None,
|
||||
one will be randomly generated.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
|
||||
if localpart is None:
|
||||
raise SynapseError(400, "Request must include user id")
|
||||
|
||||
need_register = True
|
||||
|
||||
try:
|
||||
yield self.check_username(localpart)
|
||||
except SynapseError as e:
|
||||
if e.errcode == Codes.USER_IN_USE:
|
||||
need_register = False
|
||||
else:
|
||||
raise
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
|
||||
|
||||
if need_register:
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
yield self.store.flush_user(user_id=user_id)
|
||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||
|
||||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
yield profile_handler.set_displayname(
|
||||
user, user, displayname
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
def auth_handler(self):
|
||||
return self.hs.get_handlers().auth_handler
|
||||
|
||||
|
|
|
@ -164,8 +164,8 @@ class ReplicationResource(Resource):
|
|||
"Replicating %d rows of %s from %s -> %s",
|
||||
len(stream_content["rows"]),
|
||||
stream_name,
|
||||
stream_content["position"],
|
||||
request_streams.get(stream_name),
|
||||
stream_content["position"],
|
||||
)
|
||||
|
||||
request.write(json.dumps(result, ensure_ascii=False))
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage.account_data import AccountDataStore
|
||||
|
||||
|
||||
class SlavedAccountDataStore(BaseSlavedStore):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
|
||||
self._account_data_id_gen = SlavedIdTracker(
|
||||
db_conn, "account_data_max_stream_id", "stream_id",
|
||||
)
|
||||
|
||||
get_global_account_data_by_type_for_users = (
|
||||
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
|
||||
)
|
||||
|
||||
get_global_account_data_by_type_for_user = (
|
||||
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedAccountDataStore, self).stream_positions()
|
||||
position = self._account_data_id_gen.get_current_token()
|
||||
result["user_account_data"] = position
|
||||
result["room_account_data"] = position
|
||||
result["tag_account_data"] = position
|
||||
return result
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("user_account_data")
|
||||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
user_id, data_type = row[1:3]
|
||||
self.get_global_account_data_by_type_for_user.invalidate(
|
||||
(data_type, user_id,)
|
||||
)
|
||||
|
||||
stream = result.get("room_account_data")
|
||||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
||||
|
||||
stream = result.get("tag_account_data")
|
||||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
|
@ -165,12 +165,14 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
|
||||
stream = result.get("forward_ex_outliers")
|
||||
if stream:
|
||||
self._stream_id_gen.advance(stream["position"])
|
||||
for row in stream["rows"]:
|
||||
event_id = row[1]
|
||||
self._invalidate_get_event_cache(event_id)
|
||||
|
||||
stream = result.get("backward_ex_outliers")
|
||||
if stream:
|
||||
self._backfill_id_gen.advance(-stream["position"])
|
||||
for row in stream["rows"]:
|
||||
event_id = row[1]
|
||||
self._invalidate_get_event_cache(event_id)
|
||||
|
|
|
@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
|
||||
|
||||
class CreateUserRestServlet(ClientV1RestServlet):
|
||||
"""Handles user creation via a server-to-server interface
|
||||
"""
|
||||
|
||||
PATTERNS = client_path_patterns("/createUser$", releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(CreateUserRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
user_json = parse_json_object_from_request(request)
|
||||
|
||||
if "access_token" not in request.args:
|
||||
raise SynapseError(400, "Expected application service token.")
|
||||
|
||||
app_service = yield self.store.get_app_service_by_token(
|
||||
request.args["access_token"][0]
|
||||
)
|
||||
if not app_service:
|
||||
raise SynapseError(403, "Invalid application service token.")
|
||||
|
||||
logger.debug("creating user: %s", user_json)
|
||||
|
||||
response = yield self._do_create(user_json)
|
||||
|
||||
defer.returnValue((200, response))
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return 403, {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_create(self, user_json):
|
||||
yield run_on_reactor()
|
||||
|
||||
if "localpart" not in user_json:
|
||||
raise SynapseError(400, "Expected 'localpart' key.")
|
||||
|
||||
if "displayname" not in user_json:
|
||||
raise SynapseError(400, "Expected 'displayname' key.")
|
||||
|
||||
if "duration_seconds" not in user_json:
|
||||
raise SynapseError(400, "Expected 'duration_seconds' key.")
|
||||
|
||||
localpart = user_json["localpart"].encode("utf-8")
|
||||
displayname = user_json["displayname"].encode("utf-8")
|
||||
duration_seconds = 0
|
||||
try:
|
||||
duration_seconds = int(user_json["duration_seconds"])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||
if duration_seconds > self.direct_user_creation_max_duration:
|
||||
duration_seconds = self.direct_user_creation_max_duration
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.get_or_create_user(
|
||||
localpart=localpart,
|
||||
displayname=displayname,
|
||||
duration_seconds=duration_seconds
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
CreateUserRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -453,7 +453,9 @@ class SQLBaseStore(object):
|
|||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): key/values to use when inserting
|
||||
Returns: A deferred
|
||||
Returns:
|
||||
Deferred(bool): True if a new entry was created, False if an
|
||||
existing one was updated.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
|
@ -498,6 +500,10 @@ class SQLBaseStore(object):
|
|||
)
|
||||
txn.execute(sql, allvalues.values())
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
|
|
|
@ -224,6 +224,18 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
(room_id, event_id)
|
||||
)
|
||||
|
||||
def _remove_push_actions_before_txn(self, txn, room_id, user_id,
|
||||
topological_ordering):
|
||||
txn.call_after(
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||
(room_id, user_id, )
|
||||
)
|
||||
txn.execute(
|
||||
"DELETE FROM event_push_actions"
|
||||
" WHERE room_id = ? AND user_id = ? AND topological_ordering < ?",
|
||||
(room_id, user_id, topological_ordering,)
|
||||
)
|
||||
|
||||
|
||||
def _action_has_highlight(actions):
|
||||
for action in actions:
|
||||
|
|
|
@ -156,8 +156,7 @@ class PusherStore(SQLBaseStore):
|
|||
profile_tag=""):
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
def f(txn):
|
||||
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
|
||||
return self._simple_upsert_txn(
|
||||
newly_inserted = self._simple_upsert_txn(
|
||||
txn,
|
||||
"pushers",
|
||||
{
|
||||
|
@ -178,11 +177,18 @@ class PusherStore(SQLBaseStore):
|
|||
"id": stream_id,
|
||||
},
|
||||
)
|
||||
defer.returnValue((yield self.runInteraction("add_pusher", f)))
|
||||
if newly_inserted:
|
||||
# get_users_with_pushers_in_room only cares if the user has
|
||||
# at least *one* pusher.
|
||||
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
|
||||
|
||||
yield self.runInteraction("add_pusher", f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
||||
def delete_pusher_txn(txn, stream_id):
|
||||
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
|
||||
|
||||
self._simple_delete_one_txn(
|
||||
txn,
|
||||
"pushers",
|
||||
|
@ -194,6 +200,7 @@ class PusherStore(SQLBaseStore):
|
|||
{"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
|
||||
{"stream_id": stream_id},
|
||||
)
|
||||
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
yield self.runInteraction(
|
||||
"delete_pusher", delete_pusher_txn, stream_id
|
||||
|
|
|
@ -100,7 +100,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue([ev for res in results.values() for ev in res])
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, max_entries=5000)
|
||||
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
|
||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
"""Get receipts for a single room for sending to clients.
|
||||
|
||||
|
@ -232,7 +232,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
||||
)
|
||||
# FIXME: This shouldn't invalidate the whole cache
|
||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
|
||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||
|
||||
txn.call_after(
|
||||
self._receipts_stream_cache.entity_has_changed,
|
||||
|
@ -244,6 +244,17 @@ class ReceiptsStore(SQLBaseStore):
|
|||
(user_id, room_id, receipt_type)
|
||||
)
|
||||
|
||||
res = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="events",
|
||||
retcols=["topological_ordering", "stream_ordering"],
|
||||
keyvalues={"event_id": event_id},
|
||||
allow_none=True
|
||||
)
|
||||
|
||||
topological_ordering = int(res["topological_ordering"]) if res else None
|
||||
stream_ordering = int(res["stream_ordering"]) if res else None
|
||||
|
||||
# We don't want to clobber receipts for more recent events, so we
|
||||
# have to compare orderings of existing receipts
|
||||
sql = (
|
||||
|
@ -255,16 +266,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||
txn.execute(sql, (room_id, receipt_type, user_id))
|
||||
results = txn.fetchall()
|
||||
|
||||
if results:
|
||||
res = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="events",
|
||||
retcols=["topological_ordering", "stream_ordering"],
|
||||
keyvalues={"event_id": event_id},
|
||||
)
|
||||
topological_ordering = int(res["topological_ordering"])
|
||||
stream_ordering = int(res["stream_ordering"])
|
||||
|
||||
if results and topological_ordering:
|
||||
for to, so, _ in results:
|
||||
if int(to) > topological_ordering:
|
||||
return False
|
||||
|
@ -294,6 +296,14 @@ class ReceiptsStore(SQLBaseStore):
|
|||
}
|
||||
)
|
||||
|
||||
if receipt_type == "m.read" and topological_ordering:
|
||||
self._remove_push_actions_before_txn(
|
||||
txn,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
topological_ordering=topological_ordering,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -367,7 +377,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
||||
)
|
||||
# FIXME: This shouldn't invalidate the whole cache
|
||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
|
||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
-- The following indices are redundant, other indices are equivalent or
|
||||
-- supersets
|
||||
DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
|
||||
DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
|
||||
DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
|
||||
DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
|
||||
DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
|
||||
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
|
||||
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
|
||||
|
||||
DROP INDEX IF EXISTS event_destinations_id; -- Prefix of UNIQUE CONSTRAINT
|
||||
DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT
|
||||
DROP INDEX IF EXISTS event_content_hashes_id; -- Prefix of UNIQUE CONSTRAINT
|
||||
DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT
|
||||
DROP INDEX IF EXISTS event_edge_hashes_id; -- Prefix of UNIQUE CONSTRAINT
|
||||
DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT
|
||||
DROP INDEX IF EXISTS room_hosts_room_id; -- Prefix of UNIQUE CONSTRAINT
|
||||
|
||||
-- The following indices were unused
|
||||
DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id;
|
||||
DROP INDEX IF EXISTS evauth_edges_auth_id;
|
||||
DROP INDEX IF EXISTS presence_stream_state;
|
|
@ -284,12 +284,12 @@ class AuthTestCase(unittest.TestCase):
|
|||
macaroon.add_first_party_caveat("time < 1") # ms
|
||||
|
||||
self.hs.clock.now = 5000 # seconds
|
||||
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
self.hs.config.expire_access_token = True
|
||||
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
# TODO(daniel): Turn on the check that we validate expiration, when we
|
||||
# validate expiration (and remove the above line, which will start
|
||||
# throwing).
|
||||
# with self.assertRaises(AuthError) as cm:
|
||||
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
# self.assertEqual(401, cm.exception.code)
|
||||
# self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||
with self.assertRaises(AuthError) as cm:
|
||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||
self.assertEqual(401, cm.exception.code)
|
||||
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from .. import unittest
|
||||
|
||||
from synapse.handlers.register import RegistrationHandler
|
||||
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class RegistrationHandlers(object):
|
||||
def __init__(self, hs):
|
||||
self.registration_handler = RegistrationHandler(hs)
|
||||
|
||||
|
||||
class RegistrationTestCase(unittest.TestCase):
|
||||
""" Tests the RegistrationHandler. """
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_distributor = Mock()
|
||||
self.mock_distributor.declare("registered_user")
|
||||
self.mock_captcha_client = Mock()
|
||||
hs = yield setup_test_homeserver(
|
||||
handlers=None,
|
||||
http_client=None,
|
||||
expire_access_token=True)
|
||||
hs.handlers = RegistrationHandlers(hs)
|
||||
self.handler = hs.get_handlers().registration_handler
|
||||
hs.get_handlers().profile_handler = Mock()
|
||||
self.mock_handler = Mock(spec=[
|
||||
"generate_short_term_login_token",
|
||||
])
|
||||
|
||||
hs.get_handlers().auth_handler = self.mock_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||
"""
|
||||
Returns:
|
||||
The user doess not exist in this case so it will register and log it in
|
||||
"""
|
||||
duration_ms = 200
|
||||
local_part = "someone"
|
||||
display_name = "someone"
|
||||
user_id = "@someone:test"
|
||||
mock_token = self.mock_handler.generate_short_term_login_token
|
||||
mock_token.return_value = 'secret'
|
||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||
local_part, display_name, duration_ms)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
self.assertEquals(result_token, 'secret')
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ._base import BaseSlavedStoreTestCase
|
||||
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
USER_ID = "@feeling:blue"
|
||||
TYPE = "my.type"
|
||||
|
||||
|
||||
class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
STORE_TYPE = SlavedAccountDataStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_account_data(self):
|
||||
yield self.master_store.add_account_data_for_user(
|
||||
USER_ID, TYPE, {"a": 1}
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
"get_global_account_data_by_type_for_user",
|
||||
[TYPE, USER_ID], {"a": 1}
|
||||
)
|
||||
yield self.check(
|
||||
"get_global_account_data_by_type_for_users",
|
||||
[TYPE, [USER_ID]], {USER_ID: {"a": 1}}
|
||||
)
|
||||
|
||||
yield self.master_store.add_account_data_for_user(
|
||||
USER_ID, TYPE, {"a": 2}
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
"get_global_account_data_by_type_for_user",
|
||||
[TYPE, USER_ID], {"a": 2}
|
||||
)
|
||||
yield self.check(
|
||||
"get_global_account_data_by_type_for_users",
|
||||
[TYPE, [USER_ID]], {USER_ID: {"a": 2}}
|
||||
)
|
|
@ -0,0 +1,88 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.rest.client.v1.register import CreateUserRestServlet
|
||||
from twisted.internet import defer
|
||||
from mock import Mock
|
||||
from tests import unittest
|
||||
import json
|
||||
|
||||
|
||||
class CreateUserServletTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# do the dance to hook up request data to self.request_data
|
||||
self.request_data = ""
|
||||
self.request = Mock(
|
||||
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||
path='/_matrix/client/api/v1/createUser'
|
||||
)
|
||||
self.request.args = {}
|
||||
|
||||
self.appservice = None
|
||||
self.auth = Mock(get_appservice_by_req=Mock(
|
||||
side_effect=lambda x: defer.succeed(self.appservice))
|
||||
)
|
||||
|
||||
self.auth_result = (False, None, None, None)
|
||||
self.auth_handler = Mock(
|
||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||
get_session_data=Mock(return_value=None)
|
||||
)
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
self.login_handler = Mock()
|
||||
|
||||
# do the dance to hook it up to the hs global
|
||||
self.handlers = Mock(
|
||||
auth_handler=self.auth_handler,
|
||||
registration_handler=self.registration_handler,
|
||||
identity_handler=self.identity_handler,
|
||||
login_handler=self.login_handler
|
||||
)
|
||||
self.hs = Mock()
|
||||
self.hs.hostname = "supergbig~testing~thing.com"
|
||||
self.hs.get_auth = Mock(return_value=self.auth)
|
||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||
self.hs.config.enable_registration = True
|
||||
# init the thing we're testing
|
||||
self.servlet = CreateUserRestServlet(self.hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_createuser_with_valid_user(self):
|
||||
user_id = "@someone:interesting"
|
||||
token = "my token"
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
self.request_data = json.dumps({
|
||||
"localpart": "someone",
|
||||
"displayname": "someone interesting",
|
||||
"duration_seconds": 200
|
||||
})
|
||||
|
||||
self.registration_handler.get_or_create_user = Mock(
|
||||
return_value=(user_id, token)
|
||||
)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
|
@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||
config.event_cache_size = 1
|
||||
config.enable_registration = True
|
||||
config.macaroon_secret_key = "not even a little secret"
|
||||
config.expire_access_token = False
|
||||
config.server_name = "server.under.test"
|
||||
config.trusted_third_party_id_servers = []
|
||||
config.room_invite_state_types = []
|
||||
|
|
Loading…
Reference in New Issue