Merge branch 'develop' into rav/no_more_refresh_tokens
This commit is contained in:
commit
dc4b23e1a1
|
@ -427,7 +427,7 @@ to install using pip and a virtualenv::
|
||||||
|
|
||||||
virtualenv env
|
virtualenv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
python synapse/python_dependencies.py | xargs -n1 pip install
|
python synapse/python_dependencies.py | xargs pip install
|
||||||
pip install setuptools_trial mock
|
pip install setuptools_trial mock
|
||||||
|
|
||||||
This will run a process of downloading and installing all the needed
|
This will run a process of downloading and installing all the needed
|
||||||
|
@ -650,4 +650,3 @@ matrix.org on. The default setting is currently 0.1, which is probably
|
||||||
around a ~700MB footprint. You can dial it down further to 0.02 if
|
around a ~700MB footprint. You can dial it down further to 0.02 if
|
||||||
desired, which targets roughly ~512MB. Conversely you can dial it up if
|
desired, which targets roughly ~512MB. Conversely you can dial it up if
|
||||||
you need performance for lots of users and have a box with a lot of RAM.
|
you need performance for lots of users and have a box with a lot of RAM.
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,6 @@ tox -e py27 --notest -v
|
||||||
|
|
||||||
TOX_BIN=$TOX_DIR/py27/bin
|
TOX_BIN=$TOX_DIR/py27/bin
|
||||||
$TOX_BIN/pip install setuptools
|
$TOX_BIN/pip install setuptools
|
||||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
{ python synapse/python_dependencies.py
|
||||||
$TOX_BIN/pip install lxml
|
echo lxml psycopg2
|
||||||
$TOX_BIN/pip install psycopg2
|
} | xargs $TOX_BIN/pip install
|
||||||
|
|
|
@ -39,6 +39,9 @@ AuthEventTypes = (
|
||||||
EventTypes.ThirdPartyInvite,
|
EventTypes.ThirdPartyInvite,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# guests always get this device id.
|
||||||
|
GUEST_DEVICE_ID = "guest_device"
|
||||||
|
|
||||||
|
|
||||||
class Auth(object):
|
class Auth(object):
|
||||||
"""
|
"""
|
||||||
|
@ -717,7 +720,8 @@ class Auth(object):
|
||||||
"user": user,
|
"user": user,
|
||||||
"is_guest": True,
|
"is_guest": True,
|
||||||
"token_id": None,
|
"token_id": None,
|
||||||
"device_id": None,
|
# all guests get the same device id
|
||||||
|
"device_id": GUEST_DEVICE_ID,
|
||||||
}
|
}
|
||||||
elif rights == "delete_pusher":
|
elif rights == "delete_pusher":
|
||||||
# We don't store these tokens in the database
|
# We don't store these tokens in the database
|
||||||
|
@ -790,9 +794,6 @@ class Auth(object):
|
||||||
type_string(str): The kind of token required (e.g. "access", "refresh",
|
type_string(str): The kind of token required (e.g. "access", "refresh",
|
||||||
"delete_pusher")
|
"delete_pusher")
|
||||||
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
||||||
This should really always be True, but there exist access tokens
|
|
||||||
in the wild which expire when they should not, so we can't
|
|
||||||
enforce expiry yet.
|
|
||||||
user_id (str): The user_id required
|
user_id (str): The user_id required
|
||||||
"""
|
"""
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
|
@ -805,11 +806,24 @@ class Auth(object):
|
||||||
v.satisfy_exact("type = " + type_string)
|
v.satisfy_exact("type = " + type_string)
|
||||||
v.satisfy_exact("user_id = %s" % user_id)
|
v.satisfy_exact("user_id = %s" % user_id)
|
||||||
v.satisfy_exact("guest = true")
|
v.satisfy_exact("guest = true")
|
||||||
|
|
||||||
|
# verify_expiry should really always be True, but there exist access
|
||||||
|
# tokens in the wild which expire when they should not, so we can't
|
||||||
|
# enforce expiry yet (so we have to allow any caveat starting with
|
||||||
|
# 'time < ' in access tokens).
|
||||||
|
#
|
||||||
|
# On the other hand, short-term login tokens (as used by CAS login, for
|
||||||
|
# example) have an expiry time which we do want to enforce.
|
||||||
|
|
||||||
if verify_expiry:
|
if verify_expiry:
|
||||||
v.satisfy_general(self._verify_expiry)
|
v.satisfy_general(self._verify_expiry)
|
||||||
else:
|
else:
|
||||||
v.satisfy_general(lambda c: c.startswith("time < "))
|
v.satisfy_general(lambda c: c.startswith("time < "))
|
||||||
|
|
||||||
|
# access_tokens and refresh_tokens include a nonce for uniqueness: any
|
||||||
|
# value is acceptable
|
||||||
|
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||||
|
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
def _verify_expiry(self, caveat):
|
def _verify_expiry(self, caveat):
|
||||||
|
|
|
@ -32,7 +32,6 @@ class RegistrationConfig(Config):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
self.user_creation_max_duration = int(config["user_creation_max_duration"])
|
|
||||||
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||||
|
@ -55,11 +54,6 @@ class RegistrationConfig(Config):
|
||||||
# secret, even if registration is otherwise disabled.
|
# secret, even if registration is otherwise disabled.
|
||||||
registration_shared_secret: "%(registration_shared_secret)s"
|
registration_shared_secret: "%(registration_shared_secret)s"
|
||||||
|
|
||||||
# Sets the expiry for the short term user creation in
|
|
||||||
# milliseconds. For instance the bellow duration is two weeks
|
|
||||||
# in milliseconds.
|
|
||||||
user_creation_max_duration: 1209600000
|
|
||||||
|
|
||||||
# Set the number of bcrypt rounds used to generate password hash.
|
# Set the number of bcrypt rounds used to generate password hash.
|
||||||
# Larger numbers increase the work factor needed to generate the hash.
|
# Larger numbers increase the work factor needed to generate the hash.
|
||||||
# The default number of rounds is 12.
|
# The default number of rounds is 12.
|
||||||
|
|
|
@ -526,14 +526,15 @@ class AuthHandler(BaseHandler):
|
||||||
device_id)
|
device_id)
|
||||||
defer.returnValue(access_token)
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
def generate_access_token(self, user_id, extra_caveats=None,
|
def generate_access_token(self, user_id, extra_caveats=None):
|
||||||
duration_in_ms=(60 * 60 * 1000)):
|
|
||||||
extra_caveats = extra_caveats or []
|
extra_caveats = extra_caveats or []
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
now = self.hs.get_clock().time_msec()
|
# Include a nonce, to make sure that each login gets a different
|
||||||
expiry = now + duration_in_ms
|
# access token.
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("nonce = %s" % (
|
||||||
|
stringutils.random_string_with_symbols(16),
|
||||||
|
))
|
||||||
for caveat in extra_caveats:
|
for caveat in extra_caveats:
|
||||||
macaroon.add_first_party_caveat(caveat)
|
macaroon.add_first_party_caveat(caveat)
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
|
@ -369,7 +369,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
|
def get_or_create_user(self, requester, localpart, displayname,
|
||||||
password_hash=None):
|
password_hash=None):
|
||||||
"""Creates a new user if the user does not exist,
|
"""Creates a new user if the user does not exist,
|
||||||
else revokes all previous access tokens and generates a new one.
|
else revokes all previous access tokens and generates a new one.
|
||||||
|
@ -399,8 +399,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
token = self.auth_handler().generate_access_token(
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
user_id, None, duration_in_ms)
|
|
||||||
|
|
||||||
if need_register:
|
if need_register:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
|
|
|
@ -33,6 +33,7 @@ from synapse.api.errors import (
|
||||||
|
|
||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
|
|
||||||
|
import cgi
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
@ -292,12 +293,7 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
c_type = response.headers.getRawHeaders("Content-Type")
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
if "application/json" not in c_type:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Content-Type not application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield preserve_context_over_fn(readBody, response)
|
||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
@ -342,12 +338,7 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
c_type = response.headers.getRawHeaders("Content-Type")
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
if "application/json" not in c_type:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Content-Type not application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield preserve_context_over_fn(readBody, response)
|
||||||
|
|
||||||
|
@ -400,12 +391,7 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
# We need to update the transactions table to say it was sent?
|
# We need to update the transactions table to say it was sent?
|
||||||
c_type = response.headers.getRawHeaders("Content-Type")
|
check_content_type_is_json(response.headers)
|
||||||
|
|
||||||
if "application/json" not in c_type:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Content-Type not application/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
body = yield preserve_context_over_fn(readBody, response)
|
body = yield preserve_context_over_fn(readBody, response)
|
||||||
|
|
||||||
|
@ -525,3 +511,29 @@ def _flatten_response_never_received(e):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return "%s: %s" % (type(e).__name__, e.message,)
|
return "%s: %s" % (type(e).__name__, e.message,)
|
||||||
|
|
||||||
|
|
||||||
|
def check_content_type_is_json(headers):
|
||||||
|
"""
|
||||||
|
Check that a set of HTTP headers have a Content-Type header, and that it
|
||||||
|
is application/json.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
headers (twisted.web.http_headers.Headers): headers to check
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError if the
|
||||||
|
|
||||||
|
"""
|
||||||
|
c_type = headers.getRawHeaders("Content-Type")
|
||||||
|
if c_type is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No Content-Type header"
|
||||||
|
)
|
||||||
|
|
||||||
|
c_type = c_type[0] # only the first header
|
||||||
|
val, options = cgi.parse_header(c_type)
|
||||||
|
if val != "application/json":
|
||||||
|
raise RuntimeError(
|
||||||
|
"Content-Type not application/json: was '%s'" % c_type
|
||||||
|
)
|
||||||
|
|
|
@ -384,7 +384,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(CreateUserRestServlet, self).__init__(hs)
|
super(CreateUserRestServlet, self).__init__(hs)
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -418,18 +417,8 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
if "displayname" not in user_json:
|
if "displayname" not in user_json:
|
||||||
raise SynapseError(400, "Expected 'displayname' key.")
|
raise SynapseError(400, "Expected 'displayname' key.")
|
||||||
|
|
||||||
if "duration_seconds" not in user_json:
|
|
||||||
raise SynapseError(400, "Expected 'duration_seconds' key.")
|
|
||||||
|
|
||||||
localpart = user_json["localpart"].encode("utf-8")
|
localpart = user_json["localpart"].encode("utf-8")
|
||||||
displayname = user_json["displayname"].encode("utf-8")
|
displayname = user_json["displayname"].encode("utf-8")
|
||||||
duration_seconds = 0
|
|
||||||
try:
|
|
||||||
duration_seconds = int(user_json["duration_seconds"])
|
|
||||||
except ValueError:
|
|
||||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
|
||||||
if duration_seconds > self.direct_user_creation_max_duration:
|
|
||||||
duration_seconds = self.direct_user_creation_max_duration
|
|
||||||
password_hash = user_json["password_hash"].encode("utf-8") \
|
password_hash = user_json["password_hash"].encode("utf-8") \
|
||||||
if user_json.get("password_hash") else None
|
if user_json.get("password_hash") else None
|
||||||
|
|
||||||
|
@ -438,7 +427,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
requester=requester,
|
requester=requester,
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
displayname=displayname,
|
displayname=displayname,
|
||||||
duration_in_ms=(duration_seconds * 1000),
|
|
||||||
password_hash=password_hash
|
password_hash=password_hash
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ class DevicesRestServlet(servlet.RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
devices = yield self.device_handler.get_devices_by_user(
|
devices = yield self.device_handler.get_devices_by_user(
|
||||||
requester.user.to_string()
|
requester.user.to_string()
|
||||||
)
|
)
|
||||||
|
@ -63,7 +63,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, device_id):
|
def on_GET(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
device = yield self.device_handler.get_device(
|
device = yield self.device_handler.get_device(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(),
|
||||||
device_id,
|
device_id,
|
||||||
|
@ -99,7 +99,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, device_id):
|
def on_PUT(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
body = servlet.parse_json_object_from_request(request)
|
body = servlet.parse_json_object_from_request(request)
|
||||||
yield self.device_handler.update_device(
|
yield self.device_handler.update_device(
|
||||||
|
|
|
@ -65,7 +65,7 @@ class KeyUploadServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, device_id):
|
def on_POST(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ class KeyQueryServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id, device_id):
|
def on_POST(self, request, user_id, device_id):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = yield self.e2e_keys_handler.query_devices(body, timeout)
|
result = yield self.e2e_keys_handler.query_devices(body, timeout)
|
||||||
|
@ -158,7 +158,7 @@ class KeyQueryServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id):
|
def on_GET(self, request, user_id, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
auth_user_id = requester.user.to_string()
|
auth_user_id = requester.user.to_string()
|
||||||
user_id = user_id if user_id else auth_user_id
|
user_id = user_id if user_id else auth_user_id
|
||||||
|
@ -204,7 +204,7 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id, algorithm):
|
def on_GET(self, request, user_id, device_id, algorithm):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {user_id: {device_id: algorithm}}},
|
{"one_time_keys": {user_id: {device_id: algorithm}}},
|
||||||
|
@ -214,7 +214,7 @@ class OneTimeKeyServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id, device_id, algorithm):
|
def on_POST(self, request, user_id, device_id, algorithm):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse
|
||||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||||
|
@ -100,12 +101,14 @@ class RegisterRestServlet(RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
kind = "user"
|
kind = "user"
|
||||||
if "kind" in request.args:
|
if "kind" in request.args:
|
||||||
kind = request.args["kind"][0]
|
kind = request.args["kind"][0]
|
||||||
|
|
||||||
if kind == "guest":
|
if kind == "guest":
|
||||||
ret = yield self._do_guest_registration()
|
ret = yield self._do_guest_registration(body)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
return
|
return
|
||||||
elif kind != "user":
|
elif kind != "user":
|
||||||
|
@ -113,8 +116,6 @@ class RegisterRestServlet(RestServlet):
|
||||||
"Do not understand membership kind: %s" % (kind,)
|
"Do not understand membership kind: %s" % (kind,)
|
||||||
)
|
)
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
|
||||||
|
|
||||||
# we do basic sanity checks here because the auth layer will store these
|
# we do basic sanity checks here because the auth layer will store these
|
||||||
# in sessions. Pull out the username/password provided to us.
|
# in sessions. Pull out the username/password provided to us.
|
||||||
desired_password = None
|
desired_password = None
|
||||||
|
@ -420,13 +421,22 @@ class RegisterRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_guest_registration(self):
|
def _do_guest_registration(self, params):
|
||||||
if not self.hs.config.allow_guest_access:
|
if not self.hs.config.allow_guest_access:
|
||||||
defer.returnValue((403, "Guest access is disabled"))
|
defer.returnValue((403, "Guest access is disabled"))
|
||||||
user_id, _ = yield self.registration_handler.register(
|
user_id, _ = yield self.registration_handler.register(
|
||||||
generate_token=False,
|
generate_token=False,
|
||||||
make_guest=True
|
make_guest=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# we don't allow guests to specify their own device_id, because
|
||||||
|
# we have nowhere to store it.
|
||||||
|
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||||
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
|
self.device_handler.check_device_registered(
|
||||||
|
user_id, device_id, initial_display_name
|
||||||
|
)
|
||||||
|
|
||||||
access_token = self.auth_handler.generate_access_token(
|
access_token = self.auth_handler.generate_access_token(
|
||||||
user_id, ["guest = true"]
|
user_id, ["guest = true"]
|
||||||
)
|
)
|
||||||
|
@ -434,6 +444,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
# so long as we don't return a refresh_token here.
|
# so long as we don't return a refresh_token here.
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}))
|
}))
|
||||||
|
|
|
@ -50,7 +50,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _put(self, request, message_type, txn_id):
|
def _put(self, request, message_type, txn_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
|
|
@ -61,14 +61,14 @@ class AuthTestCase(unittest.TestCase):
|
||||||
def verify_type(caveat):
|
def verify_type(caveat):
|
||||||
return caveat == "type = access"
|
return caveat == "type = access"
|
||||||
|
|
||||||
def verify_expiry(caveat):
|
def verify_nonce(caveat):
|
||||||
return caveat == "time < 8600000"
|
return caveat.startswith("nonce =")
|
||||||
|
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
v.satisfy_general(verify_gen)
|
v.satisfy_general(verify_gen)
|
||||||
v.satisfy_general(verify_user)
|
v.satisfy_general(verify_user)
|
||||||
v.satisfy_general(verify_type)
|
v.satisfy_general(verify_type)
|
||||||
v.satisfy_general(verify_expiry)
|
v.satisfy_general(verify_nonce)
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
|
|
|
@ -53,13 +53,12 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||||
duration_ms = 200
|
|
||||||
local_part = "someone"
|
local_part = "someone"
|
||||||
display_name = "someone"
|
display_name = "someone"
|
||||||
user_id = "@someone:test"
|
user_id = "@someone:test"
|
||||||
requester = create_requester("@as:test")
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
requester, local_part, display_name, duration_ms)
|
requester, local_part, display_name)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
|
@ -71,12 +70,11 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
user_id=frank.to_string(),
|
user_id=frank.to_string(),
|
||||||
token="jkv;g498752-43gj['eamb!-5",
|
token="jkv;g498752-43gj['eamb!-5",
|
||||||
password_hash=None)
|
password_hash=None)
|
||||||
duration_ms = 200
|
|
||||||
local_part = "frank"
|
local_part = "frank"
|
||||||
display_name = "Frank"
|
display_name = "Frank"
|
||||||
user_id = "@frank:test"
|
user_id = "@frank:test"
|
||||||
requester = create_requester("@as:test")
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
requester, local_part, display_name, duration_ms)
|
requester, local_part, display_name)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
Loading…
Reference in New Issue