Merge branch 'develop' of github.com:matrix-org/synapse into erikj/e2e_one_time_upsert

This commit is contained in:
Erik Johnston 2017-03-29 10:57:19 +01:00
commit 4ad613f6be
45 changed files with 831 additions and 596 deletions

View File

@ -146,6 +146,7 @@ To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
source ~/.synapse/bin/activate source ~/.synapse/bin/activate
pip install --upgrade pip
pip install --upgrade setuptools pip install --upgrade setuptools
pip install https://github.com/matrix-org/synapse/tarball/master pip install https://github.com/matrix-org/synapse/tarball/master
@ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users::
New user localpart: erikj New user localpart: erikj
Password: Password:
Confirm password: Confirm password:
Make admin [no]:
Success! Success!
This process uses a setting ``registration_shared_secret`` in This process uses a setting ``registration_shared_secret`` in

View File

@ -15,10 +15,172 @@
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from twisted.internet import defer from twisted.internet import defer
import ujson as json import ujson as json
import jsonschema
from jsonschema import FormatChecker
FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"limit": {
"type": "number"
},
"senders": {
"$ref": "#/definitions/user_id_array"
},
"not_senders": {
"$ref": "#/definitions/user_id_array"
},
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
"types": {
"type": "array",
"items": {
"type": "string"
}
},
"not_types": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
ROOM_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"not_rooms": {
"$ref": "#/definitions/room_id_array"
},
"rooms": {
"$ref": "#/definitions/room_id_array"
},
"ephemeral": {
"$ref": "#/definitions/room_event_filter"
},
"include_leave": {
"type": "boolean"
},
"state": {
"$ref": "#/definitions/room_event_filter"
},
"timeline": {
"$ref": "#/definitions/room_event_filter"
},
"account_data": {
"$ref": "#/definitions/room_event_filter"
},
}
}
ROOM_EVENT_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
"limit": {
"type": "number"
},
"senders": {
"$ref": "#/definitions/user_id_array"
},
"not_senders": {
"$ref": "#/definitions/user_id_array"
},
"types": {
"type": "array",
"items": {
"type": "string"
}
},
"not_types": {
"type": "array",
"items": {
"type": "string"
}
},
"rooms": {
"$ref": "#/definitions/room_id_array"
},
"not_rooms": {
"$ref": "#/definitions/room_id_array"
},
"contains_url": {
"type": "boolean"
}
}
}
USER_ID_ARRAY_SCHEMA = {
"type": "array",
"items": {
"type": "string",
"format": "matrix_user_id"
}
}
ROOM_ID_ARRAY_SCHEMA = {
"type": "array",
"items": {
"type": "string",
"format": "matrix_room_id"
}
}
USER_FILTER_SCHEMA = {
"$schema": "http://json-schema.org/draft-04/schema#",
"description": "schema for a Sync filter",
"type": "object",
"definitions": {
"room_id_array": ROOM_ID_ARRAY_SCHEMA,
"user_id_array": USER_ID_ARRAY_SCHEMA,
"filter": FILTER_SCHEMA,
"room_filter": ROOM_FILTER_SCHEMA,
"room_event_filter": ROOM_EVENT_FILTER_SCHEMA
},
"properties": {
"presence": {
"$ref": "#/definitions/filter"
},
"account_data": {
"$ref": "#/definitions/filter"
},
"room": {
"$ref": "#/definitions/room_filter"
},
"event_format": {
"type": "string",
"enum": ["client", "federation"]
},
"event_fields": {
"type": "array",
"items": {
"type": "string",
# Don't allow '\\' in event field filters. This makes matching
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
"pattern": "^((?!\\\).)*$"
}
}
},
"additionalProperties": False
}
@FormatChecker.cls_checks('matrix_room_id')
def matrix_room_id_validator(room_id_str):
return RoomID.from_string(room_id_str)
@FormatChecker.cls_checks('matrix_user_id')
def matrix_user_id_validator(user_id_str):
return UserID.from_string(user_id_str)
class Filtering(object): class Filtering(object):
@ -53,98 +215,11 @@ class Filtering(object):
# NB: Filters are the complete json blobs. "Definitions" are an # NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of # individual top-level key e.g. public_user_data. Filters are made of
# many definitions. # many definitions.
try:
top_level_definitions = [ jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
"presence", "account_data" format_checker=FormatChecker())
] except jsonschema.ValidationError as e:
raise SynapseError(400, e.message)
room_level_definitions = [
"state", "timeline", "ephemeral", "account_data"
]
for key in top_level_definitions:
if key in user_filter_json:
self._check_definition(user_filter_json[key])
if "room" in user_filter_json:
self._check_definition_room_lists(user_filter_json["room"])
for key in room_level_definitions:
if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key])
if "event_fields" in user_filter_json:
if type(user_filter_json["event_fields"]) != list:
raise SynapseError(400, "event_fields must be a list of strings")
for field in user_filter_json["event_fields"]:
if not isinstance(field, basestring):
raise SynapseError(400, "Event field must be a string")
# Don't allow '\\' in event field filters. This makes matching
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
if r'\\' in field:
raise SynapseError(
400, r'The escape character \ cannot itself be escaped'
)
def _check_definition_room_lists(self, definition):
"""Check that "rooms" and "not_rooms" are lists of room ids if they
are present
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# check rooms are valid room IDs
room_id_keys = ["rooms", "not_rooms"]
for key in room_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for room_id in definition[key]:
RoomID.from_string(room_id)
def _check_definition(self, definition):
"""Check if the provided definition is valid.
This inspects not only the types but also the values to make sure they
make sense.
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
if type(definition) != dict:
raise SynapseError(
400, "Expected JSON object, not %s" % (definition,)
)
self._check_definition_room_lists(definition)
# check senders are valid user IDs
user_id_keys = ["senders", "not_senders"]
for key in user_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
class FilterCollection(object): class FilterCollection(object):

View File

@ -29,6 +29,7 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
@ -63,6 +64,7 @@ class ClientReaderSlavedStore(
DirectoryStore, DirectoryStore,
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
TransactionStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):

View File

@ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer from synapse.server import HomeServer
@ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore( class MediaRepositorySlavedStore(
SlavedApplicationServiceStore, SlavedApplicationServiceStore,
SlavedRegistrationStore, SlavedRegistrationStore,
TransactionStore,
BaseSlavedStore, BaseSlavedStore,
MediaRepositoryStore, MediaRepositoryStore,
ClientIpStore, ClientIpStore,

View File

@ -15,7 +15,6 @@
from synapse.crypto.keyclient import fetch_server_key from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import ( from synapse.util.logcontext import (
@ -382,12 +381,6 @@ class Keyring(object):
def get_keys_from_server(self, server_name_and_key_ids): def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(server_name, key_ids): def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
keys = None keys = None
try: try:
keys = yield self.get_server_verify_key_v2_direct( keys = yield self.get_server_verify_key_v2_direct(

View File

@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent, builder from synapse.events import FrozenEvent, builder
import synapse.metrics import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
import copy import copy
import itertools import itertools
@ -88,7 +88,7 @@ class FederationClient(FederationBase):
@log_function @log_function
def make_query(self, destination, query_type, args, def make_query(self, destination, query_type, args,
retry_on_dns_fail=False): retry_on_dns_fail=False, ignore_backoff=False):
"""Sends a federation Query to a remote homeserver of the given type """Sends a federation Query to a remote homeserver of the given type
and arguments. and arguments.
@ -98,6 +98,8 @@ class FederationClient(FederationBase):
handler name used in register_query_handler(). handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details args (dict): Mapping of strings to strings containing the details
of the query request. of the query request.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns: Returns:
a Deferred which will eventually yield a JSON object from the a Deferred which will eventually yield a JSON object from the
@ -106,7 +108,8 @@ class FederationClient(FederationBase):
sent_queries_counter.inc(query_type) sent_queries_counter.inc(query_type)
return self.transport_layer.make_query( return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
) )
@log_function @log_function
@ -234,13 +237,6 @@ class FederationClient(FederationBase):
continue continue
try: try:
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
)
with limiter:
transaction_data = yield self.transport_layer.get_event( transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout, destination, event_id, timeout=timeout,
) )

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import datetime
from twisted.internet import defer from twisted.internet import defer
@ -22,9 +22,7 @@ from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import ( from synapse.util.retryutils import NotRetryingDestination
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
@ -311,14 +309,8 @@ class TransactionQueue(object):
# XXX: what's this for? # XXX: what's this for?
yield run_on_reactor() yield run_on_reactor()
pending_pdus = []
while True: while True:
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
backoff_on_404=True, # If we get a 404 the other side has gone
)
device_message_edus, device_stream_id, dev_list_id = ( device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination) yield self._get_new_device_messages(destination)
) )
@ -374,7 +366,6 @@ class TransactionQueue(object):
success = yield self._send_new_transaction( success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures, destination, pending_pdus, pending_edus, pending_failures,
limiter=limiter,
) )
if success: if success:
# Remove the acknowledged device messages from the database # Remove the acknowledged device messages from the database
@ -392,12 +383,24 @@ class TransactionQueue(object):
self.last_device_list_stream_id_by_dest[destination] = dev_list_id self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else: else:
break break
except NotRetryingDestination: except NotRetryingDestination as e:
logger.debug( logger.debug(
"TX [%s] not ready for retry yet - " "TX [%s] not ready for retry yet (next retry at %s) - "
"dropping transaction for now", "dropping transaction for now",
destination, destination,
datetime.datetime.fromtimestamp(
(e.retry_last_ts + e.retry_interval) / 1000.0
),
) )
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
destination,
e,
)
for p in pending_pdus:
logger.info("Failed to send event %s to %s", p.event_id,
destination)
finally: finally:
# We want to be *very* sure we delete this after we stop processing # We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None) self.pending_transactions.pop(destination, None)
@ -437,7 +440,7 @@ class TransactionQueue(object):
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, limiter): pending_failures):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
@ -447,7 +450,6 @@ class TransactionQueue(object):
success = True success = True
try:
logger.debug("TX [%s] _attempt_new_transaction", destination) logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id) txn_id = str(self._next_txn_id)
@ -488,7 +490,6 @@ class TransactionQueue(object):
len(failures), len(failures),
) )
with limiter:
# Actually send the transaction # Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age # FIXME (erikj): This is a bit of a hack to make the Pdu age
@ -548,31 +549,5 @@ class TransactionQueue(object):
"Failed to send event %s to %s", p.event_id, destination "Failed to send event %s to %s", p.event_id, destination
) )
success = False success = False
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
success = False
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
success = False
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
defer.returnValue(success) defer.returnValue(success)

View File

@ -163,6 +163,7 @@ class TransportLayerClient(object):
data=json_data, data=json_data,
json_data_callback=json_data_callback, json_data_callback=json_data_callback,
long_retries=True, long_retries=True,
backoff_on_404=True, # If we get a 404 the other side has gone
) )
logger.debug( logger.debug(
@ -174,7 +175,8 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail): def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
path = PREFIX + "/query/%s" % query_type path = PREFIX + "/query/%s" % query_type
content = yield self.client.get_json( content = yield self.client.get_json(
@ -183,6 +185,7 @@ class TransportLayerClient(object):
args=args, args=args,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=10000, timeout=10000,
ignore_backoff=ignore_backoff,
) )
defer.returnValue(content) defer.returnValue(content)
@ -242,6 +245,7 @@ class TransportLayerClient(object):
destination=destination, destination=destination,
path=path, path=path,
data=content, data=content,
ignore_backoff=True,
) )
defer.returnValue(response) defer.returnValue(response)
@ -269,6 +273,7 @@ class TransportLayerClient(object):
destination=remote_server, destination=remote_server,
path=path, path=path,
args=args, args=args,
ignore_backoff=True,
) )
defer.returnValue(response) defer.returnValue(response)

View File

@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
}, },
retry_on_dns_fail=False, retry_on_dns_fail=False,
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
logging.warn("Error retrieving alias") logging.warn("Error retrieving alias")

View File

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -121,10 +121,6 @@ class E2eKeysHandler(object):
def do_remote_query(destination): def do_remote_query(destination):
destination_query = remote_queries_not_in_cache[destination] destination_query = remote_queries_not_in_cache[destination]
try: try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.query_client_keys( remote_result = yield self.federation.query_client_keys(
destination, destination,
{"device_keys": destination_query}, {"device_keys": destination_query},
@ -239,10 +235,6 @@ class E2eKeysHandler(object):
def claim_client_keys(destination): def claim_client_keys(destination):
device_keys = remote_queries[destination] device_keys = remote_queries[destination]
try: try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.claim_client_keys( remote_result = yield self.federation.claim_client_keys(
destination, destination,
{"one_time_keys": device_keys}, {"one_time_keys": device_keys},

View File

@ -575,8 +575,7 @@ class PresenceHandler(object):
if not local_states: if not local_states:
continue continue
users = yield self.store.get_users_in_room(room_id) hosts = yield self.store.get_hosts_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
for host in hosts: for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)

View File

@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
args={ args={
"user_id": target_user.to_string(), "user_id": target_user.to_string(),
"field": "displayname", "field": "displayname",
} },
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:
@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
args={ args={
"user_id": target_user.to_string(), "user_id": target_user.to_string(),
"field": "avatar_url", "field": "avatar_url",
} },
ignore_backoff=True,
) )
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:

View File

@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.util.retryutils
from twisted.internet import defer, reactor, protocol from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent from twisted.web.client import readBody, HTTPConnectionPool, Agent
@ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import preserve_context_over_fn from synapse.util import logcontext
import synapse.metrics import synapse.metrics
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
reactor, MatrixFederationEndpointFactory(hs), pool=pool reactor, MatrixFederationEndpointFactory(hs), pool=pool
) )
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string = hs.version_string self.version_string = hs.version_string
self._next_id = 1 self._next_id = 1
@ -103,18 +103,40 @@ class MatrixFederationHttpClient(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _request(self, destination, method, path,
body_callback, headers_dict={}, param_bytes=b"", body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True, query_bytes=b"", retry_on_dns_fail=True,
timeout=None, long_retries=False): timeout=None, long_retries=False,
""" Creates and sends a request to the given url ignore_backoff=False,
backoff_on_404=False):
""" Creates and sends a request to the given server
Args:
destination (str): The remote server to send the HTTP request to.
method (str): HTTP method
path (str): The HTTP path
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): Back off if we get a 404
Returns: Returns:
Deferred: resolves with the http response object on success. Deferred: resolves with the http response object on success.
Fails with ``HTTPRequestException``: if we get an HTTP response Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300. code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
limiter = yield synapse.util.retryutils.get_retry_limiter(
destination,
self.clock,
self._store,
backoff_on_404=backoff_on_404,
ignore_backoff=ignore_backoff,
)
destination = destination.encode("ascii")
path_bytes = path.encode("ascii")
with limiter:
headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination] headers_dict[b"Host"] = [destination]
@ -150,8 +172,7 @@ class MatrixFederationHttpClient(object):
try: try:
def send_request(): def send_request():
request_deferred = preserve_context_over_fn( request_deferred = self.agent.request(
self.agent.request,
method, method,
url_bytes, url_bytes,
Headers(headers_dict), Headers(headers_dict),
@ -163,7 +184,8 @@ class MatrixFederationHttpClient(object):
time_out=timeout / 1000. if timeout else 60, time_out=timeout / 1000. if timeout else 60,
) )
response = yield preserve_context_over_fn(send_request) with logcontext.PreserveLoggingContext():
response = yield send_request()
log_result = "%d %s" % (response.code, response.phrase,) log_result = "%d %s" % (response.code, response.phrase,)
break break
@ -220,7 +242,8 @@ class MatrixFederationHttpClient(object):
else: else:
# :'( # :'(
# Update transactions table? # Update transactions table?
body = yield preserve_context_over_fn(readBody, response) with logcontext.PreserveLoggingContext():
body = yield readBody(response)
raise HttpResponseException( raise HttpResponseException(
response.code, response.phrase, body response.code, response.phrase, body
) )
@ -254,7 +277,9 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None, def put_json(self, destination, path, data={}, json_data_callback=None,
long_retries=False, timeout=None): long_retries=False, timeout=None,
ignore_backoff=False,
backoff_on_404=False):
""" Sends the specifed json data using PUT """ Sends the specifed json data using PUT
Args: Args:
@ -269,11 +294,19 @@ class MatrixFederationHttpClient(object):
retry for a short or long time. retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout. giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests)
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised. CodeMessageException is raised.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
if not json_data_callback: if not json_data_callback:
@ -288,26 +321,29 @@ class MatrixFederationHttpClient(object):
producer = _JsonProducer(json_data) producer = _JsonProducer(json_data)
return producer return producer
response = yield self._create_request( response = yield self._request(
destination.encode("ascii"), destination,
"PUT", "PUT",
path.encode("ascii"), path,
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff,
backoff_on_404=backoff_on_404,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response) with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False, def post_json(self, destination, path, data={}, long_retries=False,
timeout=None): timeout=None, ignore_backoff=False):
""" Sends the specifed json data using POST """ Sends the specifed json data using POST
Args: Args:
@ -320,11 +356,15 @@ class MatrixFederationHttpClient(object):
retry for a short or long time. retry for a short or long time.
timeout(int): How long to try (in ms) the destination for before timeout(int): How long to try (in ms) the destination for before
giving up. None indicates no timeout. giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised. CodeMessageException is raised.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
def body_callback(method, url_bytes, headers_dict): def body_callback(method, url_bytes, headers_dict):
@ -333,27 +373,29 @@ class MatrixFederationHttpClient(object):
) )
return _JsonProducer(data) return _JsonProducer(data)
response = yield self._create_request( response = yield self._request(
destination.encode("ascii"), destination,
"POST", "POST",
path.encode("ascii"), path,
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response) with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True, def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
timeout=None): timeout=None, ignore_backoff=False):
""" GETs some json from the given host homeserver and path """ GETs some json from the given host homeserver and path
Args: Args:
@ -365,11 +407,16 @@ class MatrixFederationHttpClient(object):
timeout (int): How long to try (in ms) the destination for before timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will giving up. None indicates no timeout and that the request will
be retried. be retried.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns: Returns:
Deferred: Succeeds when we get *any* HTTP response. Deferred: Succeeds when we get *any* HTTP response.
The result of the deferred is a tuple of `(code, response)`, The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body. where `response` is a dict representing the decoded JSON body.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
logger.debug("get_json args: %s", args) logger.debug("get_json args: %s", args)
@ -386,39 +433,47 @@ class MatrixFederationHttpClient(object):
self.sign_request(destination, method, url_bytes, headers_dict) self.sign_request(destination, method, url_bytes, headers_dict)
return None return None
response = yield self._create_request( response = yield self._request(
destination.encode("ascii"), destination,
"GET", "GET",
path.encode("ascii"), path,
query_bytes=query_bytes, query_bytes=query_bytes,
body_callback=body_callback, body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
check_content_type_is_json(response.headers) check_content_type_is_json(response.headers)
body = yield preserve_context_over_fn(readBody, response) with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={}, def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None): retry_on_dns_fail=True, max_size=None,
ignore_backoff=False):
"""GETs a file from a given homeserver """GETs a file from a given homeserver
Args: Args:
destination (str): The remote server to send the HTTP request to. destination (str): The remote server to send the HTTP request to.
path (str): The HTTP path to GET. path (str): The HTTP path to GET.
output_stream (file): File to write the response body to. output_stream (file): File to write the response body to.
args (dict): Optional dictionary used to create the query string. args (dict): Optional dictionary used to create the query string.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns: Returns:
Deferred: resolves with an (int,dict) tuple of the file length and Deferred: resolves with an (int,dict) tuple of the file length and
a dict of the response headers. a dict of the response headers.
Fails with ``HTTPRequestException`` if we get an HTTP response code Fails with ``HTTPRequestException`` if we get an HTTP response code
>= 300 >= 300
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
""" """
encoded_args = {} encoded_args = {}
@ -434,20 +489,21 @@ class MatrixFederationHttpClient(object):
self.sign_request(destination, method, url_bytes, headers_dict) self.sign_request(destination, method, url_bytes, headers_dict)
return None return None
response = yield self._create_request( response = yield self._request(
destination.encode("ascii"), destination,
"GET", "GET",
path.encode("ascii"), path,
query_bytes=query_bytes, query_bytes=query_bytes,
body_callback=body_callback, body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
) )
headers = dict(response.headers.getAllRawHeaders()) headers = dict(response.headers.getAllRawHeaders())
try: try:
length = yield preserve_context_over_fn( with logcontext.PreserveLoggingContext():
_readBodyToFile, length = yield _readBodyToFile(
response, output_stream, max_size response, output_stream, max_size
) )
except: except:

View File

@ -19,6 +19,7 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
"frozendict>=0.4": ["frozendict"], "frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],

View File

@ -167,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore):
_get_rooms_for_user_where_membership_is_txn = ( _get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__ DataStore._get_rooms_for_user_where_membership_is_txn.__func__
) )
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
_get_state_for_groups = DataStore._get_state_for_groups.__func__ _get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__ _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__ _get_events_around_txn = DataStore._get_events_around_txn.__func__

View File

@ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore):
self.presence_stream_cache.entity_has_changed( self.presence_stream_cache.entity_has_changed(
user_id, position user_id, position
) )
self._get_presence_for_user.invalidate((user_id,))
return super(SlavedPresenceStore, self).process_replication(result) return super(SlavedPresenceStore, self).process_replication(result)

View File

@ -268,7 +268,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if existingUid is not None: if existingUid is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body) ret = yield self.identity_handler.requestMsisdnToken(**body)
defer.returnValue((200, ret)) defer.returnValue((200, ret))

View File

@ -537,7 +537,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it. # we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
self.device_handler.check_device_registered( yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )

View File

@ -73,6 +73,9 @@ class LoggingTransaction(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self.txn, name, value) setattr(self.txn, name, value)
def __iter__(self):
return self.txn.__iter__()
def execute(self, sql, *args): def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args) self._do_execute(self.txn.execute, sql, *args)
@ -132,7 +135,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3): def interval(self, interval_duration, limit=3):
counters = [] counters = []
for name, (count, cum_time) in self.current_counters.items(): for name, (count, cum_time) in self.current_counters.iteritems():
prev_count, prev_time = self.previous_counters.get(name, (0, 0)) prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(( counters.append((
(cum_time - prev_time) / interval_duration, (cum_time - prev_time) / interval_duration,
@ -357,7 +360,7 @@ class SQLBaseStore(object):
""" """
col_headers = list(intern(column[0]) for column in cursor.description) col_headers = list(intern(column[0]) for column in cursor.description)
results = list( results = list(
dict(zip(col_headers, row)) for row in cursor.fetchall() dict(zip(col_headers, row)) for row in cursor
) )
return results return results
@ -565,7 +568,7 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol): def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else: else:
where = "" where = ""
@ -579,7 +582,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol, def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"): desc="_simple_select_onecol"):
@ -712,7 +715,7 @@ class SQLBaseStore(object):
) )
values.extend(iterable) values.extend(iterable)
for key, value in keyvalues.items(): for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@ -753,7 +756,7 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_update_one_txn(txn, table, keyvalues, updatevalues): def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else: else:
where = "" where = ""
@ -870,7 +873,7 @@ class SQLBaseStore(object):
) )
values.extend(iterable) values.extend(iterable)
for key, value in keyvalues.items(): for key, value in keyvalues.iteritems():
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@ -901,16 +904,16 @@ class SQLBaseStore(object):
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = { cache = {
row[0]: int(row[1]) row[0]: int(row[1])
for row in rows for row in txn
} }
txn.close()
if cache: if cache:
min_val = min(cache.values()) min_val = min(cache.itervalues())
else: else:
min_val = max_value min_val = max_value

View File

@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
global_account_data = { global_account_data = {
row[0]: json.loads(row[1]) for row in txn.fetchall() row[0]: json.loads(row[1]) for row in txn
} }
sql = ( sql = (
@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
account_data_by_room = {} account_data_by_room = {}
for row in txn.fetchall(): for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2]) room_account_data[row[1]] = json.loads(row[2])

View File

@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
message_json = ujson.dumps(messages_by_device["*"]) message_json = ujson.dumps(messages_by_device["*"])
for row in txn.fetchall(): for row in txn:
# Add the message for all devices for this user on this # Add the message for all devices for this user on this
# server. # server.
device = row[0] device = row[0]
@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# TODO: Maybe this needs to be done in batches if there are # TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user. # too many local devices for a given user.
txn.execute(sql, [user_id] + devices) txn.execute(sql, [user_id] + devices)
for row in txn.fetchall(): for row in txn:
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server
device = row[0] device = row[0]
@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
user_id, device_id, last_stream_id, current_stream_id, limit user_id, device_id, last_stream_id, current_stream_id, limit
)) ))
messages = [] messages = []
for row in txn.fetchall(): for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(ujson.loads(row[1])) messages.append(ujson.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:
@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" ORDER BY stream_id ASC" " ORDER BY stream_id ASC"
) )
txn.execute(sql, (last_pos, upper_pos)) txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn.fetchall()) rows.extend(txn)
return rows return rows
@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
destination, last_stream_id, current_stream_id, limit destination, last_stream_id, current_stream_id, limit
)) ))
messages = [] messages = []
for row in txn.fetchall(): for row in txn:
stream_pos = row[0] stream_pos = row[0]
messages.append(ujson.loads(row[1])) messages.append(ujson.loads(row[1]))
if len(messages) < limit: if len(messages) < limit:

View File

@ -329,17 +329,20 @@ class DeviceStore(SQLBaseStore):
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id GROUP BY user_id, device_id
LIMIT 20
""" """
txn.execute( txn.execute(
sql, (destination, from_stream_id, now_stream_id, False) sql, (destination, from_stream_id, now_stream_id, False)
) )
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id # maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows} query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in query_map.itervalues())
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True txn, query_map.keys(), include_all_devices=True
) )

View File

@ -193,7 +193,7 @@ class EndToEndKeyStore(SQLBaseStore):
) )
txn.execute(sql, (user_id, device_id)) txn.execute(sql, (user_id, device_id))
result = {} result = {}
for algorithm, key_count in txn.fetchall(): for algorithm, key_count in txn:
result[algorithm] = key_count result[algorithm] = key_count
return result return result
return self.runInteraction( return self.runInteraction(
@ -214,7 +214,7 @@ class EndToEndKeyStore(SQLBaseStore):
user_result = result.setdefault(user_id, {}) user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {}) device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm)) txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn.fetchall(): for key_id, key_json in txn:
device_result[algorithm + ":" + key_id] = key_json device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id)) delete.append((user_id, device_id, algorithm, key_id))
sql = ( sql = (

View File

@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
base_sql % (",".join(["?"] * len(chunk)),), base_sql % (",".join(["?"] * len(chunk)),),
chunk chunk
) )
new_front.update([r[0] for r in txn.fetchall()]) new_front.update([r[0] for r in txn])
new_front -= results new_front -= results
@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, False,)) txn.execute(sql, (room_id, False,))
return dict(txn.fetchall()) return dict(txn)
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore):
txn.execute(sql, (room_id, )) txn.execute(sql, (room_id, ))
results = [] results = []
for event_id, depth in txn.fetchall(): for event_id, depth in txn:
hashes = self._get_event_reference_hashes_txn(txn, event_id) hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = { prev_hashes = {
k: encode_base64(v) for k, v in hashes.items() k: encode_base64(v) for k, v in hashes.items()
@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
def get_forward_extremeties_for_room_txn(txn): def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id)) txn.execute(sql, (stream_ordering, room_id))
rows = txn.fetchall() return [event_id for event_id, in txn]
return [event_id for event_id, in rows]
return self.runInteraction( return self.runInteraction(
"get_forward_extremeties_for_room", "get_forward_extremeties_for_room",
@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results)) (room_id, event_id, False, limit - len(event_results))
) )
for row in txn.fetchall(): for row in txn:
if row[1] not in event_results: if row[1] not in event_results:
queue.put((-row[0], row[1])) queue.put((-row[0], row[1]))
@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
(room_id, event_id, False, limit - len(event_results)) (room_id, event_id, False, limit - len(event_results))
) )
for e_id, in txn.fetchall(): for e_id, in txn:
new_front.add(e_id) new_front.add(e_id)
new_front -= earliest_events new_front -= earliest_events

View File

@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
" stream_ordering >= ? AND stream_ordering <= ?" " stream_ordering >= ? AND stream_ordering <= ?"
) )
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f) ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret) defer.returnValue(ret)

View File

@ -217,14 +217,14 @@ class EventsStore(SQLBaseStore):
partitioned.setdefault(event.room_id, []).append((event, ctx)) partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = [] deferreds = []
for room_id, evs_ctxs in partitioned.items(): for room_id, evs_ctxs in partitioned.iteritems():
d = preserve_fn(self._event_persist_queue.add_to_queue)( d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs, room_id, evs_ctxs,
backfilled=backfilled, backfilled=backfilled,
) )
deferreds.append(d) deferreds.append(d)
for room_id in partitioned.keys(): for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
return preserve_context_over_deferred( return preserve_context_over_deferred(
@ -323,7 +323,7 @@ class EventsStore(SQLBaseStore):
(event, context) (event, context)
) )
for room_id, ev_ctx_rm in events_by_room.items(): for room_id, ev_ctx_rm in events_by_room.iteritems():
# Work out new extremities by recursively adding and removing # Work out new extremities by recursively adding and removing
# the new events. # the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room( latest_event_ids = yield self.get_latest_event_ids_in_room(
@ -428,6 +428,7 @@ class EventsStore(SQLBaseStore):
# Now we need to work out the different state sets for # Now we need to work out the different state sets for
# each state extremities # each state extremities
state_sets = [] state_sets = []
state_groups = set()
missing_event_ids = [] missing_event_ids = []
was_updated = False was_updated = False
for event_id in new_latest_event_ids: for event_id in new_latest_event_ids:
@ -437,9 +438,17 @@ class EventsStore(SQLBaseStore):
if event_id == ev.event_id: if event_id == ev.event_id:
if ctx.current_state_ids is None: if ctx.current_state_ids is None:
raise Exception("Unknown current state") raise Exception("Unknown current state")
# If we've already seen the state group don't bother adding
# it to the state sets again
if ctx.state_group not in state_groups:
state_sets.append(ctx.current_state_ids) state_sets.append(ctx.current_state_ids)
if ctx.delta_ids or hasattr(ev, "state_key"): if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True was_updated = True
if ctx.state_group:
# Add this as a seen state group (if it has a state
# group)
state_groups.add(ctx.state_group)
break break
else: else:
# If we couldn't find it, then we'll need to pull # If we couldn't find it, then we'll need to pull
@ -453,14 +462,20 @@ class EventsStore(SQLBaseStore):
missing_event_ids, missing_event_ids,
) )
groups = set(event_to_groups.values()) groups = set(event_to_groups.itervalues()) - state_groups
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.values()) if groups:
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids: if not new_latest_event_ids:
current_state = {} current_state = {}
elif was_updated: elif was_updated:
if len(state_sets) == 1:
# If there is only one state set, then we know what the current
# state is.
current_state = state_sets[0]
else:
# We work out the current state by passing the state sets to the # We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including # state resolution algorithm. It may ask for some events, including
# the events we have yet to persist, so we need a slightly more # the events we have yet to persist, so we need a slightly more
@ -718,7 +733,7 @@ class EventsStore(SQLBaseStore):
def _update_forward_extremities_txn(self, txn, new_forward_extremities, def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order): max_stream_order):
for room_id, new_extrem in new_forward_extremities.items(): for room_id, new_extrem in new_forward_extremities.iteritems():
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="event_forward_extremities", table="event_forward_extremities",
@ -736,7 +751,7 @@ class EventsStore(SQLBaseStore):
"event_id": ev_id, "event_id": ev_id,
"room_id": room_id, "room_id": room_id,
} }
for room_id, new_extrem in new_forward_extremities.items() for room_id, new_extrem in new_forward_extremities.iteritems()
for ev_id in new_extrem for ev_id in new_extrem
], ],
) )
@ -753,7 +768,7 @@ class EventsStore(SQLBaseStore):
"event_id": event_id, "event_id": event_id,
"stream_ordering": max_stream_order, "stream_ordering": max_stream_order,
} }
for room_id, new_extrem in new_forward_extremities.items() for room_id, new_extrem in new_forward_extremities.iteritems()
for event_id in new_extrem for event_id in new_extrem
] ]
) )
@ -807,7 +822,7 @@ class EventsStore(SQLBaseStore):
event.depth, depth_updates.get(event.room_id, event.depth) event.depth, depth_updates.get(event.room_id, event.depth)
) )
for room_id, depth in depth_updates.items(): for room_id, depth in depth_updates.iteritems():
self._update_min_depth_for_room_txn(txn, room_id, depth) self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts): def _update_outliers_txn(self, txn, events_and_contexts):
@ -834,7 +849,7 @@ class EventsStore(SQLBaseStore):
have_persisted = { have_persisted = {
event_id: outlier event_id: outlier
for event_id, outlier in txn.fetchall() for event_id, outlier in txn
} }
to_remove = set() to_remove = set()
@ -958,14 +973,10 @@ class EventsStore(SQLBaseStore):
return return
def event_dict(event): def event_dict(event):
return { d = event.get_dict()
k: v d.pop("redacted", None)
for k, v in event.get_dict().items() d.pop("redacted_because", None)
if k not in [ return d
"redacted",
"redacted_because",
]
}
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
@ -1998,7 +2009,7 @@ class EventsStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in curr_state.items() for key, state_id in curr_state.iteritems()
], ],
) )

View File

@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
), ),
(current_version,) (current_version,)
) )
applied_deltas = [d for d, in txn.fetchall()] applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded return current_version, applied_deltas, upgraded
return None return None

View File

@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
self.presence_stream_cache.entity_has_changed, self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id, state.user_id, stream_id,
) )
self._invalidate_cache_and_stream( txn.call_after(
txn, self._get_presence_for_user, (state.user_id,) self._get_presence_for_user.invalidate, (state.user_id,)
) )
# Actually insert new rows # Actually insert new rows

View File

@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
) )
txn.execute(sql, (room_id, receipt_type, user_id)) txn.execute(sql, (room_id, receipt_type, user_id))
results = txn.fetchall()
if results and topological_ordering: if topological_ordering:
for to, so, _ in results: for to, so, _ in txn:
if int(to) > topological_ordering: if int(to) > topological_ordering:
return False return False
elif int(to) == topological_ordering and int(so) >= stream_ordering: elif int(to) == topological_ordering and int(so) >= stream_ordering:

View File

@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
" WHERE lower(name) = lower(?)" " WHERE lower(name) = lower(?)"
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return dict(txn.fetchall()) return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f) return self.runInteraction("get_users_by_id_case_insensitive", f)

View File

@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
sql % ("AND appservice_id IS NULL",), sql % ("AND appservice_id IS NULL",),
(stream_id,) (stream_id,)
) )
return dict(txn.fetchall()) return dict(txn)
else: else:
# We want to get from all lists, so we need to aggregate the results # We want to get from all lists, so we need to aggregate the results
@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
results = {} results = {}
# A room is visible if its visible on any list. # A room is visible if its visible on any list.
for room_id, visibility in txn.fetchall(): for room_id, visibility in txn:
results[room_id] = bool(visibility) or results.get(room_id, False) results[room_id] = bool(visibility) or results.get(room_id, False)
return results return results

View File

@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering) yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room
"""
user_ids = yield self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate,
)
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
defer.returnValue(hosts)
@cached(max_entries=500000, iterable=True) @cached(max_entries=500000, iterable=True)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
sql = (
rows = self._get_members_rows_txn( "SELECT m.user_id FROM room_memberships as m"
txn, " INNER JOIN current_state_events as c"
room_id=room_id, " ON m.event_id = c.event_id "
membership=Membership.JOIN, " AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
) )
return [r["user_id"] for r in rows] txn.execute(sql, (room_id, Membership.JOIN,))
return [r[0] for r in txn]
return self.runInteraction("get_users_in_room", f) return self.runInteraction("get_users_in_room", f)
@cached() @cached()
@ -246,34 +259,6 @@ class RoomMemberStore(SQLBaseStore):
return results return results
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"
where_values = [room_id]
if membership:
where_clause += " AND m.membership = ?"
where_values.append(membership)
if user_id:
where_clause += " AND m.user_id = ?"
where_values.append(user_id)
sql = (
"SELECT m.* FROM room_memberships as m"
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE c.type = 'm.room.member' AND %(where)s"
) % {
"where": where_clause,
}
txn.execute(sql, where_values)
rows = self.cursor_to_dict(txn)
return rows
@cachedInlineCallbacks(max_entries=500000, iterable=True) @cachedInlineCallbacks(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id): def get_rooms_for_user(self, user_id):
"""Returns a set of room_ids the user is currently joined to """Returns a set of room_ids the user is currently joined to

View File

@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?" " WHERE event_id = ?"
) )
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()} return {k: v for k, v in txn}
def _store_event_reference_hashes_txn(self, txn, events): def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU """Store a hash for a PDU

View File

@ -90,7 +90,7 @@ class StateStore(SQLBaseStore):
event_ids, event_ids,
) )
groups = set(event_to_groups.values()) groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups) group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state) defer.returnValue(group_to_state)
@ -108,17 +108,18 @@ class StateStore(SQLBaseStore):
state_event_map = yield self.get_events( state_event_map = yield self.get_events(
[ [
ev_id for group_ids in group_to_ids.values() ev_id for group_ids in group_to_ids.itervalues()
for ev_id in group_ids.values() for ev_id in group_ids.itervalues()
], ],
get_prev_content=False get_prev_content=False
) )
defer.returnValue({ defer.returnValue({
group: [ group: [
state_event_map[v] for v in event_id_map.values() if v in state_event_map state_event_map[v] for v in event_id_map.itervalues()
if v in state_event_map
] ]
for group, event_id_map in group_to_ids.items() for group, event_id_map in group_to_ids.iteritems()
}) })
def _have_persisted_state_group_txn(self, txn, state_group): def _have_persisted_state_group_txn(self, txn, state_group):
@ -190,7 +191,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in context.delta_ids.items() for key, state_id in context.delta_ids.iteritems()
], ],
) )
else: else:
@ -205,7 +206,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in context.current_state_ids.items() for key, state_id in context.current_state_ids.iteritems()
], ],
) )
@ -217,7 +218,7 @@ class StateStore(SQLBaseStore):
"state_group": state_group_id, "state_group": state_group_id,
"event_id": event_id, "event_id": event_id,
} }
for event_id, state_group_id in state_groups.items() for event_id, state_group_id in state_groups.iteritems()
], ],
) )
@ -341,10 +342,10 @@ class StateStore(SQLBaseStore):
args.extend(where_args) args.extend(where_args)
txn.execute(sql % (where_clause,), args) txn.execute(sql % (where_clause,), args)
rows = self.cursor_to_dict(txn) for row in txn:
for row in rows: typ, state_key, event_id = row
key = (row["type"], row["state_key"]) key = (typ, state_key)
results[group][key] = row["event_id"] results[group][key] = event_id
else: else:
if types is not None: if types is not None:
where_clause = "AND (%s)" % ( where_clause = "AND (%s)" % (
@ -373,12 +374,11 @@ class StateStore(SQLBaseStore):
" WHERE state_group = ? %s" % (where_clause,), " WHERE state_group = ? %s" % (where_clause,),
args args
) )
rows = txn.fetchall() results[group].update(
results[group].update({ ((typ, state_key), event_id)
(typ, state_key): event_id for typ, state_key, event_id in txn
for typ, state_key, event_id in rows
if (typ, state_key) not in results[group] if (typ, state_key) not in results[group]
}) )
# If the lengths match then we must have all the types, # If the lengths match then we must have all the types,
# so no need to go walk further down the tree. # so no need to go walk further down the tree.
@ -415,21 +415,21 @@ class StateStore(SQLBaseStore):
event_ids, event_ids,
) )
groups = set(event_to_groups.values()) groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events( state_event_map = yield self.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
get_prev_content=False get_prev_content=False
) )
event_to_state = { event_to_state = {
event_id: { event_id: {
k: state_event_map[v] k: state_event_map[v]
for k, v in group_to_state[group].items() for k, v in group_to_state[group].iteritems()
if v in state_event_map if v in state_event_map
} }
for event_id, group in event_to_groups.items() for event_id, group in event_to_groups.iteritems()
} }
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@ -452,12 +452,12 @@ class StateStore(SQLBaseStore):
event_ids, event_ids,
) )
groups = set(event_to_groups.values()) groups = set(event_to_groups.itervalues())
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = { event_to_state = {
event_id: group_to_state[group] event_id: group_to_state[group]
for event_id, group in event_to_groups.items() for event_id, group in event_to_groups.iteritems()
} }
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@ -569,7 +569,7 @@ class StateStore(SQLBaseStore):
got_all = not (missing_types or types is None) got_all = not (missing_types or types is None)
return { return {
k: v for k, v in state_dict_ids.items() k: v for k, v in state_dict_ids.iteritems()
if include(k[0], k[1]) if include(k[0], k[1])
}, missing_types, got_all }, missing_types, got_all
@ -628,7 +628,7 @@ class StateStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched # Now we want to update the cache with all the things we fetched
# from the database. # from the database.
for group, group_state_dict in group_to_state_dict.items(): for group, group_state_dict in group_to_state_dict.iteritems():
if types: if types:
# We delibrately put key -> None mappings into the cache to # We delibrately put key -> None mappings into the cache to
# cache absence of the key, on the assumption that if we've # cache absence of the key, on the assumption that if we've
@ -643,10 +643,10 @@ class StateStore(SQLBaseStore):
else: else:
state_dict = results[group] state_dict = results[group]
state_dict.update({ state_dict.update(
(intern_string(k[0]), intern_string(k[1])): v ((intern_string(k[0]), intern_string(k[1])), v)
for k, v in group_state_dict.items() for k, v in group_state_dict.iteritems()
}) )
self._state_group_cache.update( self._state_group_cache.update(
cache_seq_num, cache_seq_num,
@ -657,10 +657,10 @@ class StateStore(SQLBaseStore):
# Remove all the entries with None values. The None values were just # Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache. # used for bookkeeping in the cache.
for group, state_dict in results.items(): for group, state_dict in results.iteritems():
results[group] = { results[group] = {
key: event_id key: event_id
for key, event_id in state_dict.items() for key, event_id in state_dict.iteritems()
if event_id if event_id
} }
@ -749,7 +749,7 @@ class StateStore(SQLBaseStore):
# of keys # of keys
delta_state = { delta_state = {
key: value for key, value in curr_state.items() key: value for key, value in curr_state.iteritems()
if prev_state.get(key, None) != value if prev_state.get(key, None) != value
} }
@ -789,7 +789,7 @@ class StateStore(SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in delta_state.items() for key, state_id in delta_state.iteritems()
], ],
) )

View File

@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
for stream_id, user_id, room_id in tag_ids: for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
tags = [] tags = []
for tag, content in txn.fetchall(): for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content) tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}" tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json)) results.append((stream_id, user_id, room_id, tag_json))
@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND stream_id > ?" " WHERE user_id = ? AND stream_id > ?"
) )
txn.execute(sql, (user_id, stream_id)) txn.execute(sql, (user_id, stream_id))
room_ids = [row[0] for row in txn.fetchall()] room_ids = [row[0] for row in txn]
return room_ids return room_ids
changed = self._account_data_stream_cache.has_entity_changed( changed = self._account_data_stream_cache.has_entity_changed(

View File

@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class DeferredTimedOutError(SynapseError): class DeferredTimedOutError(SynapseError):
def __init__(self): def __init__(self):
super(SynapseError).__init__(504, "Timed out") super(SynapseError, self).__init__(504, "Timed out")
def unwrapFirstError(failure): def unwrapFirstError(failure):
@ -93,8 +93,10 @@ class Clock(object):
ret_deferred = defer.Deferred() ret_deferred = defer.Deferred()
def timed_out_fn(): def timed_out_fn():
e = DeferredTimedOutError()
try: try:
ret_deferred.errback(DeferredTimedOutError()) ret_deferred.errback(e)
except: except:
pass pass
@ -114,7 +116,7 @@ class Clock(object):
ret_deferred.addBoth(cancel) ret_deferred.addBoth(cancel)
def sucess(res): def success(res):
try: try:
ret_deferred.callback(res) ret_deferred.callback(res)
except: except:
@ -128,7 +130,7 @@ class Clock(object):
except: except:
pass pass
given_deferred.addCallbacks(callback=sucess, errback=err) given_deferred.addCallbacks(callback=success, errback=err)
timer = self.call_later(time_out, timed_out_fn) timer = self.call_later(time_out, timed_out_fn)

View File

@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_retry_limiter(destination, clock, store, **kwargs): def get_retry_limiter(destination, clock, store, ignore_backoff=False,
**kwargs):
"""For a given destination check if we have previously failed to """For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination. send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a If we are not ready to retry the destination, this will raise a
@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
that will mark the destination as down if an exception is thrown (excluding that will mark the destination as down if an exception is thrown (excluding
CodeMessageException with code < 500) CodeMessageException with code < 500)
Args:
destination (str): name of homeserver
clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. We will still update the next
retry_interval on success/failure.
Example usage: Example usage:
try: try:
@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
now = int(clock.time_msec()) now = int(clock.time_msec())
if retry_last_ts + retry_interval > now: if not ignore_backoff and retry_last_ts + retry_interval > now:
raise NotRetryingDestination( raise NotRetryingDestination(
retry_last_ts=retry_last_ts, retry_last_ts=retry_last_ts,
retry_interval=retry_interval, retry_interval=retry_interval,
@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
valid_err_code = False valid_err_code = False
if exc_type is not None and issubclass(exc_type, CodeMessageException): if exc_type is None:
valid_err_code = True
elif not issubclass(exc_type, Exception):
# avoid treating exceptions which don't derive from Exception as
# failures; this is mostly so as not to catch defer._DefGen.
valid_err_code = True
elif issubclass(exc_type, CodeMessageException):
# Some error codes are perfectly fine for some APIs, whereas other # Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to # APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS # handle 404 as some remote servers will return a 404 when the HS
@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
else: else:
valid_err_code = False valid_err_code = False
if exc_type is None or valid_err_code: if valid_err_code:
# We connected successfully. # We connected successfully.
if not self.retry_interval: if not self.retry_interval:
return return
logger.debug("Connection to %s was successful; clearing backoff",
self.destination)
retry_last_ts = 0 retry_last_ts = 0
self.retry_interval = 0 self.retry_interval = 0
else: else:
@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
else: else:
self.retry_interval = self.min_retry_interval self.retry_interval = self.min_retry_interval
logger.debug(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
self.destination, exc_type, exc_val, self.retry_interval
)
retry_last_ts = int(self.clock.time_msec()) retry_last_ts = int(self.clock.time_msec())
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
if prev_membership not in MEMBERSHIP_PRIORITY: if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave" prev_membership = "leave"
# Always allow the user to see their own leave events, otherwise
# they won't see the room disappear if they reject the invite
if membership == "leave" and (
prev_membership == "join" or prev_membership == "invite"
):
return True
new_priority = MEMBERSHIP_PRIORITY.index(membership) new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority: if old_priority < new_priority:

View File

@ -23,6 +23,9 @@ from tests.utils import (
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.api.errors import SynapseError
import jsonschema
user_localpart = "test_user" user_localpart = "test_user"
@ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase):
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
def test_errors_on_invalid_filters(self):
invalid_filters = [
{"boom": {}},
{"account_data": "Hello World"},
{"event_fields": ["\\foo"]},
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}},
{"presence": {"senders": ["@bar;pik.test.com"]}}
]
for filter in invalid_filters:
with self.assertRaises(SynapseError) as check_filter_error:
self.filtering.check_valid_filter(filter)
self.assertIsInstance(check_filter_error.exception, SynapseError)
def test_valid_filters(self):
valid_filters = [
{
"room": {
"timeline": {"limit": 20},
"state": {"not_types": ["m.room.member"]},
"ephemeral": {"limit": 0, "not_types": ["*"]},
"include_leave": False,
"rooms": ["!dee:pik-test"],
"not_rooms": ["!gee:pik-test"],
"account_data": {"limit": 0, "types": ["*"]}
}
},
{
"room": {
"state": {
"types": ["m.room.*"],
"not_rooms": ["!726s6s6q:example.com"]
},
"timeline": {
"limit": 10,
"types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"]
},
"ephemeral": {
"types": ["m.receipt", "m.typing"],
"not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"]
}
},
"presence": {
"types": ["m.presence"],
"not_senders": ["@alice:example.com"]
},
"event_format": "client",
"event_fields": ["type", "content", "sender"]
}
]
for filter in valid_filters:
try:
self.filtering.check_valid_filter(filter)
except jsonschema.ValidationError as e:
self.fail(e)
def test_limits_are_applied(self):
# TODO
pass
def test_definition_types_works_with_literals(self): def test_definition_types_works_with_literals(self):
definition = { definition = {
"types": ["m.room.message", "org.matrix.foo.bar"] "types": ["m.room.message", "org.matrix.foo.bar"]

View File

@ -93,6 +93,7 @@ class DirectoryTestCase(unittest.TestCase):
"room_alias": "#another:remote", "room_alias": "#another:remote",
}, },
retry_on_dns_fail=False, retry_on_dns_fail=False,
ignore_backoff=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -119,7 +119,8 @@ class ProfileTestCase(unittest.TestCase):
self.mock_federation.make_query.assert_called_with( self.mock_federation.make_query.assert_called_with(
destination="remote", destination="remote",
query_type="profile", query_type="profile",
args={"user_id": "@alice:remote", "field": "displayname"} args={"user_id": "@alice:remote", "field": "displayname"},
ignore_backoff=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True, long_retries=True,
backoff_on_404=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -263,6 +264,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True, long_retries=True,
backoff_on_404=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )

View File

@ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase): class FilterTestCase(unittest.TestCase):
USER_ID = "@apple:test" USER_ID = "@apple:test"
EXAMPLE_FILTER = {"type": ["m.*"]} EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}' EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter] TO_REGISTER = [filter]
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_one_1col(self): def test_select_one_1col(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
self.mock_txn.fetchall.return_value = [("Value",)] self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore._simple_select_one_onecol( value = yield self.datastore._simple_select_one_onecol(
table="tablename", table="tablename",
@ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_select_list(self): def test_select_list(self):
self.mock_txn.rowcount = 3 self.mock_txn.rowcount = 3
self.mock_txn.fetchall.return_value = ((1,), (2,), (3,)) self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = ( self.mock_txn.description = (
("colA", None, None, None, None, None, None), ("colA", None, None, None, None, None, None),
) )

33
tests/util/test_clock.py Normal file
View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse import util
from twisted.internet import defer
from tests import unittest
class ClockTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_time_bound_deferred(self):
# just a deferred which never resolves
slow_deferred = defer.Deferred()
clock = util.Clock()
time_bound = clock.time_bound_deferred(slow_deferred, 0.001)
try:
yield time_bound
self.fail("Expected timedout error, but got nothing")
except util.DeferredTimedOutError:
pass