Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.17.0

This commit is contained in:
Erik Johnston 2016-08-08 11:24:53 +01:00
commit 3410142741
13 changed files with 125 additions and 69 deletions

View File

@ -143,6 +143,7 @@ def main():
) )
json.dump(result, sys.stdout) json.dump(result, sys.stdout)
print ""
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -165,7 +165,7 @@ def start(config_options):
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
config=config, config=config,
version_string=get_version_string("Synapse", synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )

View File

@ -285,7 +285,7 @@ def setup(config_options):
# check any extra requirements we have now we have a config # check any extra requirements we have now we have a config
check_requirements(config) check_requirements(config)
version_string = get_version_string("Synapse", synapse) version_string = "Synapse/" + get_version_string(synapse)
logger.info("Server hostname: %s", config.server_name) logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string) logger.info("Server version: %s", version_string)

View File

@ -273,7 +273,7 @@ def start(config_options):
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
config=config, config=config,
version_string=get_version_string("Synapse", synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )

View File

@ -424,7 +424,7 @@ def start(config_options):
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
config=config, config=config,
version_string=get_version_string("Synapse", synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(), application_service_handler=SynchrotronApplicationService(),
) )

View File

@ -236,9 +236,9 @@ class FederationClient(FederationBase):
# TODO: Rate limit the number of times we try and get the same event. # TODO: Rate limit the number of times we try and get the same event.
if self._get_pdu_cache: if self._get_pdu_cache:
e = self._get_pdu_cache.get(event_id) ev = self._get_pdu_cache.get(event_id)
if e: if ev:
defer.returnValue(e) defer.returnValue(ev)
pdu = None pdu = None
for destination in destinations: for destination in destinations:
@ -269,7 +269,7 @@ class FederationClient(FederationBase):
break break
except SynapseError: except SynapseError as e:
logger.info( logger.info(
"Failed to get PDU %s from %s because %s", "Failed to get PDU %s from %s because %s",
event_id, destination, e, event_id, destination, e,
@ -336,8 +336,10 @@ class FederationClient(FederationBase):
ev.event_id: ev for ev in fetched_events ev.event_id: ev for ev in fetched_events
} }
pdus = [event_map[e_id] for e_id in state_event_ids] pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [event_map[e_id] for e_id in auth_event_ids] auth_chain = [
event_map[e_id] for e_id in auth_event_ids if e_id in event_map
]
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
@ -523,14 +525,19 @@ class FederationClient(FederationBase):
(destination, self.event_from_pdu_json(pdu_dict)) (destination, self.event_from_pdu_json(pdu_dict))
) )
break break
except CodeMessageException: except CodeMessageException as e:
raise if not 500 <= e.code < 600:
raise
else:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to make_%s via %s: %s", "Failed to make_%s via %s: %s",
membership, destination, e.message membership, destination, e.message
) )
raise
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@ -602,8 +609,14 @@ class FederationClient(FederationBase):
"auth_chain": signed_auth, "auth_chain": signed_auth,
"origin": destination, "origin": destination,
}) })
except CodeMessageException: except CodeMessageException as e:
raise if not 500 <= e.code < 600:
raise
else:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to send_join via %s: %s", "Failed to send_join via %s: %s",

View File

@ -18,13 +18,14 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request, parse_string from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
import functools import functools
import logging import logging
import simplejson as json
import re import re
import synapse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
) )
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
pass
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
pass
class Authenticator(object): class Authenticator(object):
def __init__(self, hs): def __init__(self, hs):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -67,7 +78,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request): def authenticate_request(self, request, content):
json_request = { json_request = {
"method": request.method, "method": request.method,
"uri": request.uri, "uri": request.uri,
@ -75,17 +86,10 @@ class Authenticator(object):
"signatures": {}, "signatures": {},
} }
content = None if content is not None:
origin = None json_request["content"] = content
if request.method in ["PUT", "POST"]: origin = None
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
json_request["content"] = content
except:
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
def parse_auth_header(header_str): def parse_auth_header(header_str):
try: try:
@ -103,14 +107,14 @@ class Authenticator(object):
sig = strip_quotes(param_dict["sig"]) sig = strip_quotes(param_dict["sig"])
return (origin, key, sig) return (origin, key, sig)
except: except:
raise SynapseError( raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED 400, "Malformed Authorization header", Codes.UNAUTHORIZED
) )
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers: if not auth_headers:
raise SynapseError( raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
@ -121,7 +125,7 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]: if not json_request["signatures"]:
raise SynapseError( raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
@ -130,10 +134,12 @@ class Authenticator(object):
logger.info("Request from %s", origin) logger.info("Request from %s", origin)
request.authenticated_entity = origin request.authenticated_entity = origin
defer.returnValue((origin, content)) defer.returnValue(origin)
class BaseFederationServlet(object): class BaseFederationServlet(object):
REQUIRE_AUTH = True
def __init__(self, handler, authenticator, ratelimiter, server_name, def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler): room_list_handler):
self.handler = handler self.handler = handler
@ -141,29 +147,46 @@ class BaseFederationServlet(object):
self.ratelimiter = ratelimiter self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler self.room_list_handler = room_list_handler
def _wrap(self, code): def _wrap(self, func):
authenticator = self.authenticator authenticator = self.authenticator
ratelimiter = self.ratelimiter ratelimiter = self.ratelimiter
@defer.inlineCallbacks @defer.inlineCallbacks
@functools.wraps(code) @functools.wraps(func)
def new_code(request, *args, **kwargs): def new_func(request, *args, **kwargs):
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
content = parse_json_object_from_request(request)
try: try:
(origin, content) = yield authenticator.authenticate_request(request) origin = yield authenticator.authenticate_request(request, content)
with ratelimiter.ratelimit(origin) as d: except NoAuthenticationError:
yield d origin = None
response = yield code( if self.REQUIRE_AUTH:
origin, content, request.args, *args, **kwargs logger.exception("authenticate_request failed")
) raise
except: except:
logger.exception("authenticate_request failed") logger.exception("authenticate_request failed")
raise raise
if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield func(
origin, content, request.args, *args, **kwargs
)
else:
response = yield func(
origin, content, request.args, *args, **kwargs
)
defer.returnValue(response) defer.returnValue(response)
# Extra logic that functools.wraps() doesn't finish # Extra logic that functools.wraps() doesn't finish
new_code.__self__ = code.__self__ new_func.__self__ = func.__self__
return new_code return new_func
def register(self, server): def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$") pattern = re.compile("^" + PREFIX + self.PATH + "$")
@ -429,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServlet): class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind" PATH = "/3pid/onbind"
REQUIRE_AUTH = False
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, origin, content, query):
content = parse_json_object_from_request(request)
if "invites" in content: if "invites" in content:
last_exception = None last_exception = None
for invite in content["invites"]: for invite in content["invites"]:
@ -453,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
raise last_exception raise last_exception
defer.returnValue((200, {})) defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class OpenIdUserInfo(BaseFederationServlet): class OpenIdUserInfo(BaseFederationServlet):
""" """
@ -478,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
PATH = "/openid/userinfo" PATH = "/openid/userinfo"
REQUIRE_AUTH = False
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, origin, content, query):
token = parse_string(request, "access_token") token = query.get("access_token", [None])[0]
if token is None: if token is None:
defer.returnValue((401, { defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required" "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@ -497,11 +518,6 @@ class OpenIdUserInfo(BaseFederationServlet):
defer.returnValue((200, {"sub": user_id})) defer.returnValue((200, {"sub": user_id}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class PublicRoomList(BaseFederationServlet): class PublicRoomList(BaseFederationServlet):
""" """
@ -542,6 +558,20 @@ class PublicRoomList(BaseFederationServlet):
defer.returnValue((200, data)) defer.returnValue((200, data))
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
REQUIRE_AUTH = False
def on_GET(self, origin, content, query):
return defer.succeed((200, {
"server": {
"name": "Synapse",
"version": get_version_string(synapse)
},
}))
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationSendServlet, FederationSendServlet,
FederationPullServlet, FederationPullServlet,
@ -565,6 +595,7 @@ SERVLET_CLASSES = (
On3pidBindServlet, On3pidBindServlet,
OpenIdUserInfo, OpenIdUserInfo,
PublicRoomList, PublicRoomList,
FederationVersionServlet,
) )

View File

@ -68,9 +68,18 @@ class Metrics(object):
def register_memory_metrics(hs): def register_memory_metrics(hs):
metric = MemoryUsageMetric(hs) try:
import psutil
process = psutil.Process()
process.memory_info().rss
except (ImportError, AttributeError):
logger.warn(
"psutil is not installed or incorrect version."
" Disabling memory metrics."
)
return
metric = MemoryUsageMetric(hs, psutil)
all_metrics.append(metric) all_metrics.append(metric)
return metric
def get_metrics_for(pkg_name): def get_metrics_for(pkg_name):

View File

@ -16,8 +16,6 @@
from itertools import chain from itertools import chain
import psutil
# TODO(paul): I can't believe Python doesn't have one of these # TODO(paul): I can't believe Python doesn't have one of these
def map_concat(func, items): def map_concat(func, items):
@ -167,9 +165,10 @@ class MemoryUsageMetric(object):
UPDATE_HZ = 2 # number of times to get memory per second UPDATE_HZ = 2 # number of times to get memory per second
WINDOW_SIZE_SEC = 30 # the size of the window in seconds WINDOW_SIZE_SEC = 30 # the size of the window in seconds
def __init__(self, hs): def __init__(self, hs, psutil):
clock = hs.get_clock() clock = hs.get_clock()
self.memory_snapshots = [] self.memory_snapshots = []
self.process = psutil.Process() self.process = psutil.Process()
clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ) clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)

View File

@ -36,7 +36,6 @@ REQUIREMENTS = {
"blist": ["blist"], "blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],
"psutil>=2.0.0": ["psutil>=2.0.0"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {
@ -52,6 +51,9 @@ CONDITIONAL_REQUIREMENTS = {
"ldap": { "ldap": {
"ldap3>=1.0": ["ldap3>=1.0"], "ldap3>=1.0": ["ldap3>=1.0"],
}, },
"psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"],
},
} }

View File

@ -345,7 +345,8 @@ class PreviewUrlResource(Resource):
# lines) # lines)
text_nodes = ( text_nodes = (
re.sub(r'\s+', '\n', el.text).strip() re.sub(r'\s+', '\n', el.text).strip()
for el in cloned_tree.iter() if el.text for el in cloned_tree.iter()
if el.text and isinstance(el.tag, basestring) # Removes comments
) )
og['og:description'] = summarize_paragraphs(text_nodes) og['og:description'] = summarize_paragraphs(text_nodes)

View File

@ -350,7 +350,7 @@ class EventsStore(SQLBaseStore):
) )
if not events and not allow_none: if not events and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,)) raise SynapseError(404, "Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None) defer.returnValue(events[0] if events else None)

View File

@ -21,7 +21,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_version_string(name, module): def get_version_string(module):
try: try:
null = open(os.devnull, 'w') null = open(os.devnull, 'w')
cwd = os.path.dirname(os.path.abspath(module.__file__)) cwd = os.path.dirname(os.path.abspath(module.__file__))
@ -74,11 +74,11 @@ def get_version_string(name, module):
) )
return ( return (
"%s/%s (%s)" % ( "%s (%s)" % (
name, module.__version__, git_version, module.__version__, git_version,
) )
).encode("ascii") ).encode("ascii")
except Exception as e: except Exception as e:
logger.info("Failed to check for git repository: %s", e) logger.info("Failed to check for git repository: %s", e)
return ("%s/%s" % (name, module.__version__,)).encode("ascii") return module.__version__.encode("ascii")