Merge branch 'release-v0.9.2'

This commit is contained in:
Erik Johnston 2015-06-12 11:53:03 +01:00
commit 405f8c4796
33 changed files with 552 additions and 286 deletions

View File

@ -1,3 +1,26 @@
Changes in synapse v0.9.2 (2015-06-12)
======================================
General:
* Use ultrajson for json (de)serialisation when a canonical encoding is not
required. Ultrajson is significantly faster than simplejson in certain
circumstances.
* Use connection pools for outgoing HTTP connections.
* Process thumbnails on separate threads.
Configuration:
* Add option, ``gzip_responses``, to disable HTTP response compression.
Federation:
* Improve resilience of backfill by ensuring we fetch any missing auth events.
* Improve performance of backfill and joining remote rooms by removing
unnecessary computations. This included handling events we'd previously
handled as well as attempting to compute the current state for outliers.
Changes in synapse v0.9.1 (2015-05-26)
======================================

View File

@ -21,3 +21,5 @@ handlers:
root:
level: INFO
handlers: [journal]
disable_existing_loggers: False

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.9.1"
__version__ = "0.9.2"

View File

@ -54,6 +54,8 @@ from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse import events
from daemonize import Daemonize
import twisted.manhole.telnet
@ -85,10 +87,16 @@ class SynapseHomeServer(HomeServer):
return MatrixFederationHttpClient(self)
def build_resource_for_client(self):
return gz_wrap(ClientV1RestResource(self))
res = ClientV1RestResource(self)
if self.config.gzip_responses:
res = gz_wrap(res)
return res
def build_resource_for_client_v2_alpha(self):
return gz_wrap(ClientV2AlphaRestResource(self))
res = ClientV2AlphaRestResource(self)
if self.config.gzip_responses:
res = gz_wrap(res)
return res
def build_resource_for_federation(self):
return JsonResource(self)
@ -415,6 +423,8 @@ def setup(config_options):
logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name
else:

View File

@ -26,6 +26,7 @@ class CaptchaConfig(Config):
config["captcha_ip_origin_is_x_forwarded"]
)
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def default_config(self, config_dir_path, server_name):
return """\
@ -48,4 +49,7 @@ class CaptchaConfig(Config):
# A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
"""

View File

@ -39,7 +39,7 @@ class RegistrationConfig(Config):
## Registration ##
# Enable registration for new users.
enable_registration: True
enable_registration: False
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.

View File

@ -28,6 +28,8 @@ class ServerConfig(Config):
self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.gzip_responses = config["gzip_responses"]
# Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr")
@ -85,6 +87,11 @@ class ServerConfig(Config):
# Turn on the twisted telnet manhole service on localhost on the given
# port.
#manhole: 9000
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
gzip_responses: True
""" % locals()
def read_arguments(self, args):

View File

@ -16,6 +16,12 @@
from synapse.util.frozenutils import freeze
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting
# a dict to frozen_dicts is expensive.
USE_FROZEN_DICTS = True
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = dict(internal_metadata_dict)
@ -122,7 +128,10 @@ class FrozenEvent(EventBase):
unsigned = dict(event_dict.pop("unsigned", {}))
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
super(FrozenEvent, self).__init__(
frozen_dict,

View File

@ -18,8 +18,6 @@ from twisted.internet import defer
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
@ -120,16 +118,15 @@ class FederationBase(object):
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
"Signature check failed for %s",
pdu.event_id,
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
"Event content has been tampered, redacting.",
pdu.event_id,
)
defer.returnValue(redacted_event)

View File

@ -93,6 +93,8 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request)
logger.info("Request from %s", origin)
defer.returnValue((origin, content))
@log_function

View File

@ -78,7 +78,9 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
builder.prev_state = context.prev_state_events
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)

View File

@ -187,8 +187,8 @@ class AuthHandler(BaseHandler):
# each request
try:
client = SimpleHttpClient(self.hs)
data = yield client.post_urlencoded_get_json(
"https://www.google.com/recaptcha/api/siteverify",
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
@ -199,6 +199,7 @@ class AuthHandler(BaseHandler):
# Twisted is silly
data = pde.response
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

View File

@ -247,9 +247,15 @@ class FederationHandler(BaseHandler):
if set(e_id for e_id, _ in ev.prev_events) - event_ids
]
logger.info(
"backfill: Got %d events with %d edges",
len(events), len(edges),
)
# For each edge get the current state.
auth_events = {}
state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room(
@ -258,12 +264,46 @@ class FederationHandler(BaseHandler):
event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
seen_events = yield self.store.have_events(
set(auth_events.keys()) | set(state_events.keys())
)
all_events = events + state_events.values() + auth_events.values()
required_auth = set(
a_id for event in all_events for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
results = yield defer.gatherResults(
[
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
yield defer.gatherResults(
[
self._handle_new_event(dest, a)
self._handle_new_event(
dest, a,
auth_events={
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
},
)
for a in auth_events.values()
if a.event_id not in seen_events
],
consumeErrors=True,
).addErrback(unwrapFirstError)
@ -274,6 +314,11 @@ class FederationHandler(BaseHandler):
dest, event_map[e_id],
state=events_to_state[e_id],
backfilled=True,
auth_events={
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
},
)
for e_id in events_to_state
],
@ -900,8 +945,10 @@ class FederationHandler(BaseHandler):
event.event_id, event.signatures,
)
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context(
event, old_state=state
event, old_state=state, outlier=outlier,
)
if not auth_events:
@ -912,7 +959,7 @@ class FederationHandler(BaseHandler):
event.event_id, auth_events,
)
is_new_state = not event.internal_metadata.is_outlier()
is_new_state = not outlier
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.

View File

@ -20,7 +20,8 @@ import synapse.metrics
from twisted.internet import defer, reactor
from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError
Agent, readBody, FileBodyProducer, PartialDownloadError,
HTTPConnectionPool,
)
from twisted.web.http_headers import Headers
@ -55,7 +56,9 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(reactor)
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = Agent(reactor, pool=pool)
self.version_string = hs.version_string
def request(self, method, *args, **kwargs):

View File

@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
@ -103,7 +103,9 @@ class MatrixFederationHttpClient(object):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.agent = MatrixFederationHttpAgent(reactor)
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
self.clock = hs.get_clock()
self.version_string = hs.version_string

View File

@ -19,9 +19,10 @@ from synapse.api.errors import (
)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
import synapse.metrics
import synapse.events
from syutil.jsonutil import (
encode_canonical_json, encode_pretty_printed_json
encode_canonical_json, encode_pretty_printed_json, encode_json
)
from twisted.internet import defer
@ -168,9 +169,10 @@ class JsonResource(HttpServer, resource.Resource):
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
def __init__(self, hs):
def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self)
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.version_string = hs.version_string
@ -256,6 +258,7 @@ class JsonResource(HttpServer, resource.Resource):
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
canonical_json=self.canonical_json,
)
@ -277,11 +280,16 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string=""):
version_string="", canonical_json=True):
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
if canonical_json:
json_bytes = encode_canonical_json(json_object)
else:
json_bytes = encode_json(
json_object, using_frozen_dicts=synapse.events.USE_FROZEN_DICTS
)
return respond_with_json_bytes(
request, code, json_bytes,

View File

@ -24,6 +24,7 @@ import baserules
import logging
import simplejson as json
import re
import random
logger = logging.getLogger(__name__)
@ -256,12 +257,31 @@ class Pusher(object):
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
wait = 0
while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
yield self.get_and_dispatch()
wait = 0
except:
if wait == 0:
wait = 1
else:
wait = min(wait * 2, 1800)
logger.exception(
"Exception in pusher loop for pushkey %s. Pausing for %ds",
self.pushkey, wait
)
@defer.inlineCallbacks
def get_and_dispatch(self):
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=100*365*24*60*60*1000, affect_presence=False
timeout=timeout, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
@ -273,10 +293,11 @@ class Pusher(object):
break
if not single_event:
self.last_token = chunk['end']
continue
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
return
if not self.alive:
continue
return
processed = False
actions = yield self._actions_for_event(single_event)
@ -319,7 +340,7 @@ class Pusher(object):
)
if not self.alive:
continue
return
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF

View File

@ -18,7 +18,7 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.6": ["syutil>=0.0.6"],
"syutil>=0.0.7": ["syutil>=0.0.7"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -30,6 +30,7 @@ REQUIREMENTS = {
"frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
"ujson": ["ujson"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
@ -52,8 +53,8 @@ def github_link(project, version, egg):
DEPENDENCY_LINKS = [
github_link(
project="matrix-org/syutil",
version="v0.0.6",
egg="syutil-0.0.6",
version="v0.0.7",
egg="syutil-0.0.7",
),
github_link(
project="matrix-org/matrix-angular-sdk",

View File

@ -25,7 +25,7 @@ class ClientV1RestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs)
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod

View File

@ -28,7 +28,7 @@ class ClientV2AlphaRestResource(JsonResource):
"""A resource for version 2 alpha of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs)
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod

View File

@ -15,13 +15,14 @@
from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
from twisted.internet import defer
from twisted.internet import defer, threads
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
@ -52,7 +53,7 @@ class BaseMediaResource(Resource):
def __init__(self, hs, filepaths):
Resource.__init__(self)
self.auth = hs.get_auth()
self.client = hs.get_http_client()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@ -273,11 +274,14 @@ class BaseMediaResource(Resource):
if not requirements:
return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
def generate_thumbnails():
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
@ -303,10 +307,10 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail(
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
])
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
@ -320,10 +324,15 @@ class BaseMediaResource(Resource):
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail(
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
])
yield threads.deferToThread(generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({
"width": m_width,

View File

@ -106,7 +106,7 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
def compute_event_context(self, event, old_state=None, outlier=False):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
@ -119,9 +119,23 @@ class StateHandler(object):
Returns:
an EventContext
"""
yield run_on_reactor()
context = EventContext()
yield run_on_reactor()
if outlier:
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
context.current_state = {
(s.type, s.state_key): s for s in old_state
}
else:
context.current_state = {}
context.prev_state_events = []
context.state_group = None
defer.returnValue(context)
if old_state:
context.current_state = {
@ -155,10 +169,6 @@ class StateHandler(object):
context.current_state = curr_state
context.state_group = group if not event.is_state() else None
prev_state = yield self.store.add_event_hashes(
prev_state
)
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:

View File

@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 19
SCHEMA_VERSION = 20
dir_path = os.path.abspath(os.path.dirname(__file__))
@ -348,7 +348,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur)
module.run_upgrade(cur, database_engine)
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)

View File

@ -127,7 +127,7 @@ class Cache(object):
self.cache.clear()
def cached(max_entries=1000, num_args=1, lru=False):
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def wrap(orig):
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
self.orig = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
cache = Cache(
name=orig.__name__,
max_entries=max_entries,
keylen=num_args,
lru=lru,
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
@functools.wraps(orig)
@functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(self, *keyargs):
def wrapped(*keyargs):
try:
cached_result = cache.get(*keyargs)
cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES:
actual_result = yield orig(self, *keyargs)
actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
orig.__name__, keyargs,
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = yield orig(self, *keyargs)
ret = yield self.orig(obj, *keyargs)
cache.update(sequence, *keyargs + (ret,))
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
defer.returnValue(ret)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
return wrap
def cached(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
class LoggingTransaction(object):

View File

@ -17,7 +17,7 @@ from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor
from synapse.events import FrozenEvent
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred
@ -26,11 +26,11 @@ from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
from syutil.jsonutil import encode_json
from contextlib import contextmanager
import logging
import simplejson as json
import ujson as json
logger = logging.getLogger(__name__)
@ -166,8 +166,9 @@ class EventsStore(SQLBaseStore):
allow_none=True,
)
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
metadata_json = encode_json(
event.internal_metadata.get_dict(),
using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
# If we have already persisted this event, we don't need to do any
@ -235,12 +236,14 @@ class EventsStore(SQLBaseStore):
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
"json": encode_canonical_json(event_dict).decode("UTF-8"),
"json": encode_json(
event_dict, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8"),
},
)
content = encode_canonical_json(
event.content
content = encode_json(
event.content, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
vals = {
@ -266,8 +269,8 @@ class EventsStore(SQLBaseStore):
]
}
vals["unrecognized_keys"] = encode_canonical_json(
unrec
vals["unrecognized_keys"] = encode_json(
unrec, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
sql = (
@ -733,7 +736,8 @@ class EventsStore(SQLBaseStore):
because = yield self.get_event(
redaction_id,
check_redacted=False
check_redacted=False,
allow_none=True,
)
if because:
@ -743,6 +747,7 @@ class EventsStore(SQLBaseStore):
prev = yield self.get_event(
ev.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]

View File

@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur):
def run_upgrade(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:

View File

@ -0,0 +1,76 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main purpose of this upgrade is to change the unique key on the
pushers table again (it was missed when the v16 full schema was
made) but this also changes the pushkey and data columns to text.
When selecting a bytea column into a text column, postgres inserts
the hex encoded data, and there's no portable way of getting the
UTF-8 bytes, so we have to do it in Python.
"""
import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
access_token BIGINT DEFAULT NULL,
profile_tag VARCHAR(32) NOT NULL,
kind VARCHAR(8) NOT NULL,
app_id VARCHAR(64) NOT NULL,
app_display_name VARCHAR(64) NOT NULL,
device_display_name VARCHAR(128) NOT NULL,
pushkey TEXT NOT NULL,
ts BIGINT NOT NULL,
lang VARCHAR(8),
data TEXT,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
""")
cur.execute("""SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
""")
count = 0
for row in cur.fetchall():
row = list(row)
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
cur.execute(database_engine.convert_param_style("""
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
row
)
count += 1
cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count)

View File

@ -81,19 +81,23 @@ class StateStore(SQLBaseStore):
f,
)
@defer.inlineCallbacks
def c(vals):
vals[:] = yield self._get_events(vals, get_prev_content=False)
yield defer.gatherResults(
state_list = yield defer.gatherResults(
[
c(vals)
for vals in states.values()
self._fetch_events_for_group(group, vals)
for group, vals in states.items()
],
consumeErrors=True,
)
defer.returnValue(states)
defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, state_group, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (state_group, evs)
)
def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None:

View File

@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
class JsonEncodedObject(object):
""" A common base class for defining protocol units that are represented
@ -76,15 +74,7 @@ class JsonEncodedObject(object):
if k in self.valid_keys and k not in self.internal_keys
}
d.update(self.unrecognized_keys)
return copy.deepcopy(d)
def get_full_dict(self):
d = {
k: _encode(v) for (k, v) in self.__dict__.items()
if k in self.valid_keys or k in self.internal_keys
}
d.update(self.unrecognized_keys)
return copy.deepcopy(d)
return d
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))

View File

@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase):
return defer.succeed({})
self.datastore.have_events.side_effect = have_events
def annotate(ev, old_state=None):
def annotate(ev, old_state=None, outlier=False):
context = Mock()
context.current_state = {}
context.auth_events = {}
@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase):
)
self.state_handler.compute_event_context.assert_called_once_with(
ANY, old_state=None,
ANY, old_state=None, outlier=False
)
self.auth.check.assert_called_once_with(ANY, auth_events={})

View File

@ -42,6 +42,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room",
"store_room",
"get_latest_events_in_room",
"add_event_hashes",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
@ -88,6 +89,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1)
self.datastore.add_event_hashes.return_value = []
@defer.inlineCallbacks
def test_invite(self):

View File

@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_passthrough(self):
class A(object):
@cached()
def func(self, key):
return key
self.assertEquals((yield func(self, "foo")), "foo")
self.assertEquals((yield func(self, "bar")), "bar")
a = A()
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
yield func(self, "foo")
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals((yield func(self, "foo")), "foo")
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
yield func(self, "foo")
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
func.invalidate("foo")
a.func.invalidate("foo")
yield func(self, "foo")
yield a.func("foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
class A(object):
@cached()
def func(self, key):
return key
func.invalidate("what")
A().func.invalidate("what")
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
class A(object):
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
for k in range(0,12):
yield func(self, k)
a = A()
for k in range(0, 12):
yield a.func(k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0,12):
yield func(self, k)
yield a.func(k)
self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0]))
@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase):
def test_prefill(self):
callcount = [0]
class A(object):
@cached()
def func(self, key):
callcount[0] += 1
return key
func.prefill("foo", 123)
a = A()
self.assertEquals((yield func(self, "foo")), 123)
a.func.prefill("foo", 123)
self.assertEquals((yield a.func("foo")), 123)
self.assertEquals(callcount[0], 0)

View File

@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id))
)
result = yield self.store.get_user_by_token(self.tokens[1])
result = yield self.store.get_user_by_token(self.tokens[0])
self.assertDictContainsSubset(
{