Merge branch 'test-sqlite-memory' of github.com:matrix-org/synapse into develop

Conflicts:
	tests/handlers/test_profile.py
This commit is contained in:
Erik Johnston 2014-09-18 14:31:47 +01:00
commit 335e5d131c
23 changed files with 1066 additions and 300 deletions

View File

@ -156,7 +156,8 @@ class SynapseEvent(JsonEncodedObject):
return "Missing %s key" % key return "Missing %s key" % key
if type(content[key]) != type(template[key]): if type(content[key]) != type(template[key]):
return "Key %s is of the wrong type." % key return "Key %s is of the wrong type (got %s, want %s)" % (
key, type(content[key]), type(template[key]))
if type(content[key]) == dict: if type(content[key]) == dict:
# we must go deeper # we must go deeper

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import read_schema from synapse.storage import prepare_database
from synapse.server import HomeServer from synapse.server import HomeServer
@ -36,30 +36,14 @@ from daemonize import Daemonize
import twisted.manhole.telnet import twisted.manhole.telnet
import logging import logging
import sqlite3
import os import os
import re import re
import sys import sys
import sqlite3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
"im",
"room_aliases",
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 3
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def build_http_client(self): def build_http_client(self):
@ -80,52 +64,12 @@ class SynapseHomeServer(HomeServer):
) )
def build_db_pool(self): def build_db_pool(self):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we return adbapi.ConnectionPool(
don't have to worry about overwriting existing content. "sqlite3", self.get_db_name(),
""" check_same_thread=False,
logging.info("Preparing database: %s...", self.db_name) cp_min=1,
cp_max=1
with sqlite3.connect(self.db_name) as db_conn: )
c = db_conn.cursor()
c.execute("PRAGMA user_version")
row = c.fetchone()
if row and row[0]:
user_version = row[0]
if user_version > SCHEMA_VERSION:
raise ValueError("Cannot use this database as it is too " +
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
logging.info("Upgrading database from version %d",
user_version
)
# Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script)
db_conn.commit()
else:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()
logging.info("Database prepared in %s.", self.db_name)
pool = adbapi.ConnectionPool(
'sqlite3', self.db_name, check_same_thread=False,
cp_min=1, cp_max=1)
return pool
def create_resource_tree(self, web_client, redirect_root_to_web_client): def create_resource_tree(self, web_client, redirect_root_to_web_client):
"""Create the resource tree for this Home Server. """Create the resource tree for this Home Server.
@ -230,10 +174,6 @@ class SynapseHomeServer(HomeServer):
logger.info("Synapse now listening on port %d", unsecure_port) logger.info("Synapse now listening on port %d", unsecure_port)
def run():
reactor.run()
def setup(): def setup():
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
"Synapse Homeserver", "Synapse Homeserver",
@ -268,7 +208,15 @@ def setup():
web_client=config.webclient, web_client=config.webclient,
redirect_root_to_web_client=True, redirect_root_to_web_client=True,
) )
hs.start_listening(config.bind_port, config.unsecure_port)
db_name = hs.get_db_name()
logging.info("Preparing database: %s...", db_name)
with sqlite3.connect(db_name) as db_conn:
prepare_database(db_conn)
logging.info("Database prepared in %s.", db_name)
hs.get_db_pool() hs.get_db_pool()
@ -279,12 +227,14 @@ def setup():
f.namespace['hs'] = hs f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1') reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
hs.start_listening(config.bind_port, config.unsecure_port)
if config.daemonize: if config.daemonize:
print config.pid_file print config.pid_file
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",
pid=config.pid_file, pid=config.pid_file,
action=run, action=reactor.run,
auto_close_fds=False, auto_close_fds=False,
verbose=True, verbose=True,
logger=logger, logger=logger,
@ -292,7 +242,7 @@ def setup():
daemon.start() daemon.start()
else: else:
run() reactor.run()
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -58,6 +58,7 @@ class BaseHomeServer(object):
DEPENDENCIES = [ DEPENDENCIES = [
'clock', 'clock',
'http_client', 'http_client',
'db_name',
'db_pool', 'db_pool',
'persistence_service', 'persistence_service',
'replication_layer', 'replication_layer',

View File

@ -47,6 +47,23 @@ import os
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
"im",
"room_aliases",
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 3
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying
something went wrong. something went wrong.
@ -78,7 +95,7 @@ class DataStore(RoomMemberStore, RoomStore,
stream_ordering = self.min_token stream_ordering = self.min_token
try: try:
yield self._db_pool.runInteraction( yield self.runInteraction(
self._persist_pdu_event_txn, self._persist_pdu_event_txn,
pdu=pdu, pdu=pdu,
event=event, event=event,
@ -291,7 +308,7 @@ class DataStore(RoomMemberStore, RoomStore,
prev_state_pdu=prev_state_pdu, prev_state_pdu=prev_state_pdu,
) )
return self._db_pool.runInteraction(_snapshot) return self.runInteraction(_snapshot)
class Snapshot(object): class Snapshot(object):
@ -361,3 +378,42 @@ def read_schema(schema):
""" """
with open(schema_path(schema)) as schema_file: with open(schema_path(schema)) as schema_file:
return schema_file.read() return schema_file.read()
def prepare_database(db_conn):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
don't have to worry about overwriting existing content.
"""
c = db_conn.cursor()
c.execute("PRAGMA user_version")
row = c.fetchone()
if row and row[0]:
user_version = row[0]
if user_version > SCHEMA_VERSION:
raise ValueError("Cannot use this database as it is too " +
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
logging.info("Upgrading database from version %d",
user_version
)
# Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1):
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script)
db_conn.commit()
else:
for sql_loc in SCHEMAS:
sql_script = read_schema(sql_loc)
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()

View File

@ -26,6 +26,44 @@ import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
__slots__ = ["txn"]
def __init__(self, txn):
object.__setattr__(self, "txn", txn)
def __getattribute__(self, name):
if name == "execute":
return object.__getattribute__(self, "execute")
return getattr(object.__getattribute__(self, "txn"), name)
def __setattr__(self, name, value):
setattr(object.__getattribute__(self, "txn"), name, value)
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] %s", sql)
try:
if args and args[0]:
values = args[0]
sql_logger.debug("[SQL values] " +
", ".join(("<%s>",) * len(values)), *values)
except:
# Don't let logging failures stop SQL from working
pass
# TODO(paul): Here would be an excellent place to put some timing
# measurements, and log (warning?) slow queries.
return object.__getattribute__(self, "txn").execute(
sql, *args, **kwargs
)
class SQLBaseStore(object): class SQLBaseStore(object):
@ -35,6 +73,13 @@ class SQLBaseStore(object):
self.event_factory = hs.get_event_factory() self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock() self._clock = hs.get_clock()
def runInteraction(self, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
def inner_func(txn, *args, **kwargs):
return func(LoggingTransaction(txn), *args, **kwargs)
return self._db_pool.runInteraction(inner_func, *args, **kwargs)
def cursor_to_dict(self, cursor): def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts. """Converts a SQL cursor into an list of dicts.
@ -60,11 +105,6 @@ class SQLBaseStore(object):
Returns: Returns:
The result of decoder(results) The result of decoder(results)
""" """
logger.debug(
"[SQL] %s Args=%s Func=%s",
query, args, decoder.__name__ if decoder else None
)
def interaction(txn): def interaction(txn):
cursor = txn.execute(query, args) cursor = txn.execute(query, args)
if decoder: if decoder:
@ -72,7 +112,7 @@ class SQLBaseStore(object):
else: else:
return cursor.fetchall() return cursor.fetchall()
return self._db_pool.runInteraction(interaction) return self.runInteraction(interaction)
def _execute_and_decode(self, query, *args): def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args) return self._execute(self.cursor_to_dict, query, *args)
@ -88,7 +128,7 @@ class SQLBaseStore(object):
values : dict of new column names and values for them values : dict of new column names and values for them
or_replace : bool; if True performs an INSERT OR REPLACE or_replace : bool; if True performs an INSERT OR REPLACE
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._simple_insert_txn, table, values, or_replace=or_replace self._simple_insert_txn, table, values, or_replace=or_replace
) )
@ -172,7 +212,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
return txn.fetchall() return txn.fetchall()
res = yield self._db_pool.runInteraction(func) res = yield self.runInteraction(func)
defer.returnValue([r[0] for r in res]) defer.returnValue([r[0] for r in res])
@ -195,7 +235,7 @@ class SQLBaseStore(object):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
return self._db_pool.runInteraction(func) return self.runInteraction(func)
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None): retcols=None):
@ -263,7 +303,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched") raise StoreError(500, "More than one row matched")
return ret return ret
return self._db_pool.runInteraction(func) return self.runInteraction(func)
def _simple_delete_one(self, table, keyvalues): def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
@ -284,7 +324,7 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "more than one row matched") raise StoreError(500, "more than one row matched")
return self._db_pool.runInteraction(func) return self.runInteraction(func)
def _simple_max_id(self, table): def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the """Executes a SELECT query on the named table, expecting to return the
@ -302,7 +342,7 @@ class SQLBaseStore(object):
return 0 return 0
return max_id return max_id
return self._db_pool.runInteraction(func) return self.runInteraction(func)
def _parse_event_from_row(self, row_dict): def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items() if v}) d = copy.deepcopy({k: v for k, v in row_dict.items() if v})
@ -325,7 +365,7 @@ class SQLBaseStore(object):
) )
def _parse_events(self, rows): def _parse_events(self, rows):
return self._db_pool.runInteraction(self._parse_events_txn, rows) return self.runInteraction(self._parse_events_txn, rows)
def _parse_events_txn(self, txn, rows): def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows] events = [self._parse_event_from_row(r) for r in rows]

View File

@ -43,7 +43,7 @@ class PduStore(SQLBaseStore):
PduTuple: If the pdu does not exist in the database, returns None PduTuple: If the pdu does not exist in the database, returns None
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_pdu_tuple, pdu_id, origin self._get_pdu_tuple, pdu_id, origin
) )
@ -95,7 +95,7 @@ class PduStore(SQLBaseStore):
list: A list of PduTuples list: A list of PduTuples
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_current_state_for_context, self._get_current_state_for_context,
context context
) )
@ -143,7 +143,7 @@ class PduStore(SQLBaseStore):
pdu_origin (str) pdu_origin (str)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._mark_as_processed, pdu_id, pdu_origin self._mark_as_processed, pdu_id, pdu_origin
) )
@ -152,7 +152,7 @@ class PduStore(SQLBaseStore):
def get_all_pdus_from_context(self, context): def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context.""" """Get a list of all PDUs for a given context."""
return self._db_pool.runInteraction( return self.runInteraction(
self._get_all_pdus_from_context, context, self._get_all_pdus_from_context, context,
) )
@ -179,7 +179,7 @@ class PduStore(SQLBaseStore):
Return: Return:
list: A list of PduTuples list: A list of PduTuples
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_backfill, context, pdu_list, limit self._get_backfill, context, pdu_list, limit
) )
@ -240,7 +240,7 @@ class PduStore(SQLBaseStore):
txn txn
context (str) context (str)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_min_depth_for_context, context self._get_min_depth_for_context, context
) )
@ -346,7 +346,7 @@ class PduStore(SQLBaseStore):
bool bool
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._is_pdu_new, self._is_pdu_new,
pdu_id=pdu_id, pdu_id=pdu_id,
origin=origin, origin=origin,
@ -499,7 +499,7 @@ class StatePduStore(SQLBaseStore):
) )
def get_unresolved_state_tree(self, new_state_pdu): def get_unresolved_state_tree(self, new_state_pdu):
return self._db_pool.runInteraction( return self.runInteraction(
self._get_unresolved_state_tree, new_state_pdu self._get_unresolved_state_tree, new_state_pdu
) )
@ -538,7 +538,7 @@ class StatePduStore(SQLBaseStore):
def update_current_state(self, pdu_id, origin, context, pdu_type, def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key): state_key):
return self._db_pool.runInteraction( return self.runInteraction(
self._update_current_state, self._update_current_state,
pdu_id, origin, context, pdu_type, state_key pdu_id, origin, context, pdu_type, state_key
) )
@ -577,7 +577,7 @@ class StatePduStore(SQLBaseStore):
PduEntry PduEntry
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_current_state_pdu, context, pdu_type, state_key self._get_current_state_pdu, context, pdu_type, state_key
) )
@ -636,7 +636,7 @@ class StatePduStore(SQLBaseStore):
Returns: Returns:
bool: True if the new_pdu clobbered the current state, False if not bool: True if the new_pdu clobbered the current state, False if not
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._handle_new_state, new_pdu self._handle_new_state, new_pdu
) )

View File

@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
""" """
yield self._db_pool.runInteraction(self._register, user_id, token, yield self.runInteraction(self._register, user_id, token,
password_hash) password_hash)
def _register(self, txn, user_id, token, password_hash): def _register(self, txn, user_id, token, password_hash):
@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore):
Raises: Raises:
StoreError if no user was found. StoreError if no user was found.
""" """
user_id = yield self._db_pool.runInteraction(self._query_for_auth, user_id = yield self.runInteraction(self._query_for_auth,
token) token)
defer.returnValue(user_id) defer.returnValue(user_id)

View File

@ -149,7 +149,7 @@ class RoomStore(SQLBaseStore):
defer.returnValue(None) defer.returnValue(None)
def get_power_level(self, room_id, user_id): def get_power_level(self, room_id, user_id):
return self._db_pool.runInteraction( return self.runInteraction(
self._get_power_level, self._get_power_level,
room_id, user_id, room_id, user_id,
) )
@ -182,7 +182,7 @@ class RoomStore(SQLBaseStore):
return None return None
def get_ops_levels(self, room_id): def get_ops_levels(self, room_id):
return self._db_pool.runInteraction( return self.runInteraction(
self._get_ops_levels, self._get_ops_levels,
room_id, room_id,
) )

View File

@ -149,7 +149,7 @@ class RoomMemberStore(SQLBaseStore):
membership_list (list): A list of synapse.api.constants.Membership membership_list (list): A list of synapse.api.constants.Membership
values which the user must be in. values which the user must be in.
Returns: Returns:
A list of dicts with "room_id" and "membership" keys. A list of RoomMemberEvent objects
""" """
if not membership_list: if not membership_list:
return defer.succeed(None) return defer.succeed(None)
@ -198,10 +198,11 @@ class RoomMemberStore(SQLBaseStore):
return results return results
@defer.inlineCallbacks @defer.inlineCallbacks
def user_rooms_intersect(self, user_list): def user_rooms_intersect(self, user_id_list):
""" Checks whether a list of users share a room. """ Checks whether all the users whose IDs are given in a list share a
room.
""" """
user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_list)) user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list))
sql = ( sql = (
"SELECT m.room_id FROM room_memberships as m " "SELECT m.room_id FROM room_memberships as m "
"INNER JOIN current_state_events as c " "INNER JOIN current_state_events as c "
@ -211,8 +212,8 @@ class RoomMemberStore(SQLBaseStore):
"GROUP BY m.room_id HAVING COUNT(m.room_id) = ?" "GROUP BY m.room_id HAVING COUNT(m.room_id) = ?"
) % {"clause": user_list_clause} ) % {"clause": user_list_clause}
args = user_list args = list(user_id_list)
args.append(len(user_list)) args.append(len(user_id_list))
rows = yield self._execute(None, sql, *args) rows = yield self._execute(None, sql, *args)

View File

@ -286,7 +286,7 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret) defer.returnValue(ret)
def get_room_events_max_id(self): def get_room_events_max_id(self):
return self._db_pool.runInteraction(self._get_room_events_max_id_txn) return self.runInteraction(self._get_room_events_max_id_txn)
def _get_room_events_max_id_txn(self, txn): def _get_room_events_max_id_txn(self, txn):
txn.execute( txn.execute(

View File

@ -41,7 +41,7 @@ class TransactionStore(SQLBaseStore):
this transaction or a 2-tuple of (int, dict) this transaction or a 2-tuple of (int, dict)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_received_txn_response, transaction_id, origin self._get_received_txn_response, transaction_id, origin
) )
@ -72,7 +72,7 @@ class TransactionStore(SQLBaseStore):
response_json (str) response_json (str)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._set_received_txn_response, self._set_received_txn_response,
transaction_id, origin, code, response_dict transaction_id, origin, code, response_dict
) )
@ -104,7 +104,7 @@ class TransactionStore(SQLBaseStore):
list: A list of previous transaction ids. list: A list of previous transaction ids.
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._prep_send_transaction, self._prep_send_transaction,
transaction_id, destination, ts, pdu_list transaction_id, destination, ts, pdu_list
) )
@ -159,7 +159,7 @@ class TransactionStore(SQLBaseStore):
code (int) code (int)
response_json (str) response_json (str)
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._delivered_txn, self._delivered_txn,
transaction_id, destination, code, response_dict transaction_id, destination, code, response_dict
) )
@ -184,7 +184,7 @@ class TransactionStore(SQLBaseStore):
Returns: Returns:
list: A list of `ReceivedTransactionsTable.EntryType` list: A list of `ReceivedTransactionsTable.EntryType`
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_transactions_after, transaction_id, destination self._get_transactions_after, transaction_id, destination
) )
@ -214,7 +214,7 @@ class TransactionStore(SQLBaseStore):
Returns Returns
list: A list of PduTuple list: A list of PduTuple
""" """
return self._db_pool.runInteraction( return self.runInteraction(
self._get_pdus_after_transaction, self._get_pdus_after_transaction,
transaction_id, destination transaction_id, destination
) )

View File

@ -24,6 +24,8 @@ from synapse.http.client import HttpClient
from synapse.handlers.directory import DirectoryHandler from synapse.handlers.directory import DirectoryHandler
from synapse.storage.directory import RoomAliasMapping from synapse.storage.directory import RoomAliasMapping
from tests.utils import SQLiteMemoryDbPool
class DirectoryHandlers(object): class DirectoryHandlers(object):
def __init__(self, hs): def __init__(self, hs):
@ -33,6 +35,7 @@ class DirectoryHandlers(object):
class DirectoryTestCase(unittest.TestCase): class DirectoryTestCase(unittest.TestCase):
""" Tests the directory service. """ """ Tests the directory service. """
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
@ -43,11 +46,11 @@ class DirectoryTestCase(unittest.TestCase):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler self.mock_federation.register_query_handler = register_query_handler
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test", hs = HomeServer("test",
datastore=Mock(spec=[ db_pool=db_pool,
"get_association_from_room_alias",
"get_joined_hosts_for_room",
]),
http_client=None, http_client=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
replication_layer=self.mock_federation, replication_layer=self.mock_federation,
@ -56,20 +59,16 @@ class DirectoryTestCase(unittest.TestCase):
self.handler = hs.get_handlers().directory_handler self.handler = hs.get_handlers().directory_handler
self.datastore = hs.get_datastore() self.store = hs.get_datastore()
def hosts(room_id):
return defer.succeed([])
self.datastore.get_joined_hosts_for_room.side_effect = hosts
self.my_room = hs.parse_roomalias("#my-room:test") self.my_room = hs.parse_roomalias("#my-room:test")
self.your_room = hs.parse_roomalias("#your-room:test")
self.remote_room = hs.parse_roomalias("#another:remote") self.remote_room = hs.parse_roomalias("#another:remote")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_local_association(self): def test_get_local_association(self):
mocked_get = self.datastore.get_association_from_room_alias yield self.store.create_room_alias_association(
mocked_get.return_value = defer.succeed( self.my_room, "!8765qwer:test", ["test"]
RoomAliasMapping("!8765qwer:test", "#my-room:test", ["test"])
) )
result = yield self.handler.get_association(self.my_room) result = yield self.handler.get_association(self.my_room)
@ -102,9 +101,8 @@ class DirectoryTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_incoming_fed_query(self): def test_incoming_fed_query(self):
mocked_get = self.datastore.get_association_from_room_alias yield self.store.create_room_alias_association(
mocked_get.return_value = defer.succeed( self.your_room, "!8765asdf:test", ["test"]
RoomAliasMapping("!8765asdf:test", "#your-room:test", ["test"])
) )
response = yield self.query_handlers["directory"]( response = yield self.query_handlers["directory"](

View File

@ -20,7 +20,9 @@ from twisted.internet import defer, reactor
from mock import Mock, call, ANY from mock import Mock, call, ANY
import json import json
from ..utils import MockHttpResource, MockClock, DeferredMockCallable from tests.utils import (
MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool
)
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
@ -60,30 +62,21 @@ class JustPresenceHandlers(object):
class PresenceStateTestCase(unittest.TestCase): class PresenceStateTestCase(unittest.TestCase):
""" Tests presence management. """ """ Tests presence management. """
@defer.inlineCallbacks
def setUp(self): def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test", hs = HomeServer("test",
clock=MockClock(), clock=MockClock(),
db_pool=None, db_pool=db_pool,
datastore=Mock(spec=[ handlers=None,
"get_presence_state", resource_for_federation=Mock(),
"set_presence_state", http_client=None,
"add_presence_list_pending", )
"set_presence_list_accepted",
]),
handlers=None,
resource_for_federation=Mock(),
http_client=None,
)
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore() self.store = hs.get_datastore()
def is_presence_visible(observed_localpart, observer_userid):
allow = (observed_localpart == "apple" and
observer_userid == "@banana:test"
)
return defer.succeed(allow)
self.datastore.is_presence_visible = is_presence_visible
# Mock the RoomMemberHandler # Mock the RoomMemberHandler
room_member_handler = Mock(spec=[]) room_member_handler = Mock(spec=[])
@ -94,6 +87,11 @@ class PresenceStateTestCase(unittest.TestCase):
self.u_banana = hs.parse_userid("@banana:test") self.u_banana = hs.parse_userid("@banana:test")
self.u_clementine = hs.parse_userid("@clementine:test") self.u_clementine = hs.parse_userid("@clementine:test")
yield self.store.create_presence(self.u_apple.localpart)
yield self.store.set_presence_state(
self.u_apple.localpart, {"state": ONLINE, "status_msg": "Online"}
)
self.handler = hs.get_handlers().presence_handler self.handler = hs.get_handlers().presence_handler
self.room_members = [] self.room_members = []
@ -117,7 +115,7 @@ class PresenceStateTestCase(unittest.TestCase):
shared = all(map(lambda i: i in room_member_ids, userlist)) shared = all(map(lambda i: i in room_member_ids, userlist))
return defer.succeed(shared) return defer.succeed(shared)
self.datastore.user_rooms_intersect = user_rooms_intersect self.store.user_rooms_intersect = user_rooms_intersect
self.mock_start = Mock() self.mock_start = Mock()
self.mock_stop = Mock() self.mock_stop = Mock()
@ -127,11 +125,6 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_state(self): def test_get_my_state(self):
mocked_get = self.datastore.get_presence_state
mocked_get.return_value = defer.succeed(
{"state": ONLINE, "status_msg": "Online"}
)
state = yield self.handler.get_state( state = yield self.handler.get_state(
target_user=self.u_apple, auth_user=self.u_apple target_user=self.u_apple, auth_user=self.u_apple
) )
@ -140,13 +133,12 @@ class PresenceStateTestCase(unittest.TestCase):
{"presence": ONLINE, "status_msg": "Online"}, {"presence": ONLINE, "status_msg": "Online"},
state state
) )
mocked_get.assert_called_with("apple")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_allowed_state(self): def test_get_allowed_state(self):
mocked_get = self.datastore.get_presence_state yield self.store.allow_presence_visible(
mocked_get.return_value = defer.succeed( observed_localpart=self.u_apple.localpart,
{"state": ONLINE, "status_msg": "Online"} observer_userid=self.u_banana.to_string(),
) )
state = yield self.handler.get_state( state = yield self.handler.get_state(
@ -157,15 +149,9 @@ class PresenceStateTestCase(unittest.TestCase):
{"presence": ONLINE, "status_msg": "Online"}, {"presence": ONLINE, "status_msg": "Online"},
state state
) )
mocked_get.assert_called_with("apple")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_same_room_state(self): def test_get_same_room_state(self):
mocked_get = self.datastore.get_presence_state
mocked_get.return_value = defer.succeed(
{"state": ONLINE, "status_msg": "Online"}
)
self.room_members = [self.u_apple, self.u_clementine] self.room_members = [self.u_apple, self.u_clementine]
state = yield self.handler.get_state( state = yield self.handler.get_state(
@ -179,11 +165,6 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_disallowed_state(self): def test_get_disallowed_state(self):
mocked_get = self.datastore.get_presence_state
mocked_get.return_value = defer.succeed(
{"state": ONLINE, "status_msg": "Online"}
)
self.room_members = [] self.room_members = []
yield self.assertFailure( yield self.assertFailure(
@ -195,16 +176,17 @@ class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_state(self): def test_set_my_state(self):
mocked_set = self.datastore.set_presence_state
mocked_set.return_value = defer.succeed({"state": OFFLINE})
yield self.handler.set_state( yield self.handler.set_state(
target_user=self.u_apple, auth_user=self.u_apple, target_user=self.u_apple, auth_user=self.u_apple,
state={"presence": UNAVAILABLE, "status_msg": "Away"}) state={"presence": UNAVAILABLE, "status_msg": "Away"})
mocked_set.assert_called_with("apple", self.assertEquals(
{"state": UNAVAILABLE, "status_msg": "Away"} {"state": UNAVAILABLE,
"status_msg": "Away",
"mtime": 1000000},
(yield self.store.get_presence_state(self.u_apple.localpart))
) )
self.mock_start.assert_called_with(self.u_apple, self.mock_start.assert_called_with(self.u_apple,
state={ state={
"presence": UNAVAILABLE, "presence": UNAVAILABLE,
@ -222,50 +204,34 @@ class PresenceStateTestCase(unittest.TestCase):
class PresenceInvitesTestCase(unittest.TestCase): class PresenceInvitesTestCase(unittest.TestCase):
""" Tests presence management. """ """ Tests presence management. """
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_http_client = Mock(spec=[]) self.mock_http_client = Mock(spec=[])
self.mock_http_client.put_json = DeferredMockCallable() self.mock_http_client.put_json = DeferredMockCallable()
self.mock_federation_resource = MockHttpResource() self.mock_federation_resource = MockHttpResource()
hs = HomeServer("test", db_pool = SQLiteMemoryDbPool()
clock=MockClock(), yield db_pool.prepare()
db_pool=None,
datastore=Mock(spec=[
"has_presence_state",
"allow_presence_visible",
"add_presence_list_pending",
"set_presence_list_accepted",
"get_presence_list",
"del_presence_list",
# Bits that Federation needs hs = HomeServer("test",
"prep_send_transaction", clock=MockClock(),
"delivered_txn", db_pool=db_pool,
"get_received_txn_response", handlers=None,
"set_received_txn_response", resource_for_client=Mock(),
]), resource_for_federation=self.mock_federation_resource,
handlers=None, http_client=self.mock_http_client,
resource_for_client=Mock(), )
resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client,
)
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore() self.store = hs.get_datastore()
def has_presence_state(user_localpart):
return defer.succeed(
user_localpart in ("apple", "banana"))
self.datastore.has_presence_state = has_presence_state
def get_received_txn_response(*args):
return defer.succeed(None)
self.datastore.get_received_txn_response = get_received_txn_response
# Some local users to test with # Some local users to test with
self.u_apple = hs.parse_userid("@apple:test") self.u_apple = hs.parse_userid("@apple:test")
self.u_banana = hs.parse_userid("@banana:test") self.u_banana = hs.parse_userid("@banana:test")
yield self.store.create_presence(self.u_apple.localpart)
yield self.store.create_presence(self.u_banana.localpart)
# ID of a local user that does not exist # ID of a local user that does not exist
self.u_durian = hs.parse_userid("@durian:test") self.u_durian = hs.parse_userid("@durian:test")
@ -288,12 +254,16 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite( yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_banana) observer_user=self.u_apple, observed_user=self.u_banana)
self.datastore.add_presence_list_pending.assert_called_with( self.assertEquals(
"apple", "@banana:test") [{"observed_user_id": "@banana:test", "accepted": 1}],
self.datastore.allow_presence_visible.assert_called_with( (yield self.store.get_presence_list(self.u_apple.localpart))
"banana", "@apple:test") )
self.datastore.set_presence_list_accepted.assert_called_with( self.assertTrue(
"apple", "@banana:test") (yield self.store.is_presence_visible(
observed_localpart=self.u_banana.localpart,
observer_userid=self.u_apple.to_string(),
))
)
self.mock_start.assert_called_with( self.mock_start.assert_called_with(
self.u_apple, target_user=self.u_banana) self.u_apple, target_user=self.u_banana)
@ -303,10 +273,10 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite( yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_durian) observer_user=self.u_apple, observed_user=self.u_durian)
self.datastore.add_presence_list_pending.assert_called_with( self.assertEquals(
"apple", "@durian:test") [],
self.datastore.del_presence_list.assert_called_with( (yield self.store.get_presence_list(self.u_apple.localpart))
"apple", "@durian:test") )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite_remote(self): def test_invite_remote(self):
@ -328,8 +298,10 @@ class PresenceInvitesTestCase(unittest.TestCase):
yield self.handler.send_invite( yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_cabbage) observer_user=self.u_apple, observed_user=self.u_cabbage)
self.datastore.add_presence_list_pending.assert_called_with( self.assertEquals(
"apple", "@cabbage:elsewhere") [{"observed_user_id": "@cabbage:elsewhere", "accepted": 0}],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
yield put_json.await_calls() yield put_json.await_calls()
@ -362,8 +334,12 @@ class PresenceInvitesTestCase(unittest.TestCase):
) )
) )
self.datastore.allow_presence_visible.assert_called_with( self.assertTrue(
"apple", "@cabbage:elsewhere") (yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_cabbage.to_string(),
))
)
yield put_json.await_calls() yield put_json.await_calls()
@ -398,6 +374,11 @@ class PresenceInvitesTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_accepted_remote(self): def test_accepted_remote(self):
yield self.store.add_presence_list_pending(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_cabbage.to_string(),
)
yield self.mock_federation_resource.trigger("PUT", yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_accept", _make_edu_json("elsewhere", "m.presence_accept",
@ -408,14 +389,21 @@ class PresenceInvitesTestCase(unittest.TestCase):
) )
) )
self.datastore.set_presence_list_accepted.assert_called_with( self.assertEquals(
"apple", "@cabbage:elsewhere") [{"observed_user_id": "@cabbage:elsewhere", "accepted": 1}],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
self.mock_start.assert_called_with( self.mock_start.assert_called_with(
self.u_apple, target_user=self.u_cabbage) self.u_apple, target_user=self.u_cabbage)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_denied_remote(self): def test_denied_remote(self):
yield self.store.add_presence_list_pending(
observer_localpart=self.u_apple.localpart,
observed_userid="@eggplant:elsewhere",
)
yield self.mock_federation_resource.trigger("PUT", yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_deny", _make_edu_json("elsewhere", "m.presence_deny",
@ -426,32 +414,65 @@ class PresenceInvitesTestCase(unittest.TestCase):
) )
) )
self.datastore.del_presence_list.assert_called_with( self.assertEquals(
"apple", "@eggplant:elsewhere") [],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_drop_local(self): def test_drop_local(self):
yield self.handler.drop( yield self.store.add_presence_list_pending(
observer_user=self.u_apple, observed_user=self.u_banana) observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.datastore.del_presence_list.assert_called_with( yield self.handler.drop(
"apple", "@banana:test") observer_user=self.u_apple,
observed_user=self.u_banana,
)
self.assertEquals(
[],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
self.mock_stop.assert_called_with( self.mock_stop.assert_called_with(
self.u_apple, target_user=self.u_banana) self.u_apple, target_user=self.u_banana)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_drop_remote(self): def test_drop_remote(self):
yield self.handler.drop( yield self.store.add_presence_list_pending(
observer_user=self.u_apple, observed_user=self.u_cabbage) observer_localpart=self.u_apple.localpart,
observed_userid=self.u_cabbage.to_string(),
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_cabbage.to_string(),
)
self.datastore.del_presence_list.assert_called_with( yield self.handler.drop(
"apple", "@cabbage:elsewhere") observer_user=self.u_apple,
observed_user=self.u_cabbage,
)
self.assertEquals(
[],
(yield self.store.get_presence_list(self.u_apple.localpart))
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_presence_list(self): def test_get_presence_list(self):
self.datastore.get_presence_list.return_value = defer.succeed( yield self.store.add_presence_list_pending(
[{"observed_user_id": "@banana:test"}] observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
) )
presence = yield self.handler.get_presence_list( presence = yield self.handler.get_presence_list(
@ -459,29 +480,10 @@ class PresenceInvitesTestCase(unittest.TestCase):
self.assertEquals([ self.assertEquals([
{"observed_user": self.u_banana, {"observed_user": self.u_banana,
"presence": OFFLINE}, "presence": OFFLINE,
"accepted": 1},
], presence) ], presence)
self.datastore.get_presence_list.assert_called_with("apple",
accepted=None
)
self.datastore.get_presence_list.return_value = defer.succeed(
[{"observed_user_id": "@banana:test"}]
)
presence = yield self.handler.get_presence_list(
observer_user=self.u_apple, accepted=True
)
self.assertEquals([
{"observed_user": self.u_banana,
"presence": OFFLINE},
], presence)
self.datastore.get_presence_list.assert_called_with("apple",
accepted=True)
class PresencePushTestCase(unittest.TestCase): class PresencePushTestCase(unittest.TestCase):
""" Tests steady-state presence status updates. """ Tests steady-state presence status updates.

View File

@ -24,6 +24,8 @@ from synapse.server import HomeServer
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import ProfileHandler
from synapse.api.constants import Membership from synapse.api.constants import Membership
from tests.utils import SQLiteMemoryDbPool
class ProfileHandlers(object): class ProfileHandlers(object):
def __init__(self, hs): def __init__(self, hs):
@ -33,6 +35,7 @@ class ProfileHandlers(object):
class ProfileTestCase(unittest.TestCase): class ProfileTestCase(unittest.TestCase):
""" Tests profile management. """ """ Tests profile management. """
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
@ -43,63 +46,50 @@ class ProfileTestCase(unittest.TestCase):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler self.mock_federation.register_query_handler = register_query_handler
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test", hs = HomeServer("test",
db_pool=None, db_pool=db_pool,
http_client=None, http_client=None,
datastore=Mock(spec=[
"get_profile_displayname",
"set_profile_displayname",
"get_profile_avatar_url",
"set_profile_avatar_url",
"get_rooms_for_user_where_membership_is",
]),
handlers=None, handlers=None,
resource_for_federation=Mock(), resource_for_federation=Mock(),
replication_layer=self.mock_federation, replication_layer=self.mock_federation,
) )
hs.handlers = ProfileHandlers(hs) hs.handlers = ProfileHandlers(hs)
self.datastore = hs.get_datastore() self.store = hs.get_datastore()
self.frank = hs.parse_userid("@1234ABCD:test") self.frank = hs.parse_userid("@1234ABCD:test")
self.bob = hs.parse_userid("@4567:test") self.bob = hs.parse_userid("@4567:test")
self.alice = hs.parse_userid("@alice:remote") self.alice = hs.parse_userid("@alice:remote")
self.handler = hs.get_handlers().profile_handler yield self.store.create_profile(self.frank.localpart)
self.mock_get_joined = ( self.handler = hs.get_handlers().profile_handler
self.datastore.get_rooms_for_user_where_membership_is
)
# TODO(paul): Icky signal declarings.. booo # TODO(paul): Icky signal declarings.. booo
hs.get_distributor().declare("changed_presencelike_data") hs.get_distributor().declare("changed_presencelike_data")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):
mocked_get = self.datastore.get_profile_displayname yield self.store.set_profile_displayname(
mocked_get.return_value = defer.succeed("Frank") self.frank.localpart, "Frank"
)
displayname = yield self.handler.get_displayname(self.frank) displayname = yield self.handler.get_displayname(self.frank)
self.assertEquals("Frank", displayname) self.assertEquals("Frank", displayname)
mocked_get.assert_called_with("1234ABCD")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name(self): def test_set_my_name(self):
mocked_set = self.datastore.set_profile_displayname
mocked_set.return_value = defer.succeed(())
self.mock_get_joined.return_value = defer.succeed([])
yield self.handler.set_displayname(self.frank, self.frank, "Frank Jr.") yield self.handler.set_displayname(self.frank, self.frank, "Frank Jr.")
self.mock_get_joined.assert_called_once_with( self.assertEquals(
self.frank.to_string(), (yield self.store.get_profile_displayname(self.frank.localpart)),
[Membership.JOIN] "Frank Jr."
) )
mocked_set.assert_called_with("1234ABCD", "Frank Jr.")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self):
d = self.handler.set_displayname(self.frank, self.bob, "Frank Jr.") d = self.handler.set_displayname(self.frank, self.bob, "Frank Jr.")
@ -123,40 +113,31 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_incoming_fed_query(self): def test_incoming_fed_query(self):
mocked_get = self.datastore.get_profile_displayname yield self.store.create_profile("caroline")
mocked_get.return_value = defer.succeed("Caroline") yield self.store.set_profile_displayname("caroline", "Caroline")
response = yield self.query_handlers["profile"]( response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"} {"user_id": "@caroline:test", "field": "displayname"}
) )
self.assertEquals({"displayname": "Caroline"}, response) self.assertEquals({"displayname": "Caroline"}, response)
mocked_get.assert_called_with("caroline")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_avatar(self): def test_get_my_avatar(self):
mocked_get = self.datastore.get_profile_avatar_url yield self.store.set_profile_avatar_url(
mocked_get.return_value = defer.succeed("http://my.server/me.png") self.frank.localpart, "http://my.server/me.png"
)
avatar_url = yield self.handler.get_avatar_url(self.frank) avatar_url = yield self.handler.get_avatar_url(self.frank)
self.assertEquals("http://my.server/me.png", avatar_url) self.assertEquals("http://my.server/me.png", avatar_url)
mocked_get.assert_called_with("1234ABCD")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_avatar(self): def test_set_my_avatar(self):
mocked_set = self.datastore.set_profile_avatar_url
mocked_set.return_value = defer.succeed(())
self.mock_get_joined.return_value = defer.succeed([])
yield self.handler.set_avatar_url(self.frank, self.frank, yield self.handler.set_avatar_url(self.frank, self.frank,
"http://my.server/pic.gif") "http://my.server/pic.gif")
self.mock_get_joined.assert_called_once_with( self.assertEquals(
self.frank.to_string(), (yield self.store.get_profile_avatar_url(self.frank.localpart)),
[Membership.JOIN] "http://my.server/pic.gif"
) )
mocked_set.assert_called_with("1234ABCD", "http://my.server/pic.gif")

View File

@ -0,0 +1,5 @@
synapse/storage/feedback.py
synapse/storage/keys.py
synapse/storage/pdu.py
synapse/storage/stream.py
synapse/storage/transactions.py

View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage.directory import DirectoryStore
from tests.utils import SQLiteMemoryDbPool
class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
self.store = DirectoryStore(hs)
self.room = hs.parse_roomid("!abcde:test")
self.alias = hs.parse_roomalias("#my-room:test")
@defer.inlineCallbacks
def test_room_to_alias(self):
yield self.store.create_room_alias_association(
room_alias=self.alias,
room_id=self.room.to_string(),
servers=["test"],
)
self.assertEquals(
["#my-room:test"],
(yield self.store.get_aliases_for_room(self.room.to_string()))
)
@defer.inlineCallbacks
def test_alias_to_room(self):
yield self.store.create_room_alias_association(
room_alias=self.alias,
room_id=self.room.to_string(),
servers=["test"],
)
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(),
"servers": ["test"]},
(yield self.store.get_association_from_room_alias(self.alias))
)

View File

@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage.presence import PresenceStore
from tests.utils import SQLiteMemoryDbPool, MockClock
class PresenceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
clock=MockClock(),
db_pool=db_pool,
)
self.store = PresenceStore(hs)
self.u_apple = hs.parse_userid("@apple:test")
self.u_banana = hs.parse_userid("@banana:test")
@defer.inlineCallbacks
def test_state(self):
yield self.store.create_presence(
self.u_apple.localpart
)
state = yield self.store.get_presence_state(
self.u_apple.localpart
)
self.assertEquals(
{"state": None, "status_msg": None, "mtime": None}, state
)
yield self.store.set_presence_state(
self.u_apple.localpart, {"state": "online", "status_msg": "Here"}
)
state = yield self.store.get_presence_state(
self.u_apple.localpart
)
self.assertEquals(
{"state": "online", "status_msg": "Here", "mtime": 1000000}, state
)
@defer.inlineCallbacks
def test_visibility(self):
self.assertFalse((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
yield self.store.allow_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)
self.assertTrue((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
yield self.store.disallow_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)
self.assertFalse((yield self.store.is_presence_visible(
observed_localpart=self.u_apple.localpart,
observer_userid=self.u_banana.to_string(),
)))
@defer.inlineCallbacks
def test_presence_list(self):
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
))
)
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
accepted=True,
))
)
yield self.store.add_presence_list_pending(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 0}],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
))
)
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
accepted=True,
))
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
))
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
accepted=True,
))
)
yield self.store.del_presence_list(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
))
)
self.assertEquals(
[],
(yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart,
accepted=True,
))
)

View File

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage.profile import ProfileStore
from tests.utils import SQLiteMemoryDbPool
class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
self.store = ProfileStore(hs)
self.u_frank = hs.parse_userid("@frank:test")
@defer.inlineCallbacks
def test_displayname(self):
yield self.store.create_profile(
self.u_frank.localpart
)
yield self.store.set_profile_displayname(
self.u_frank.localpart, "Frank"
)
self.assertEquals(
"Frank",
(yield self.store.get_profile_displayname(self.u_frank.localpart))
)
@defer.inlineCallbacks
def test_avatar_url(self):
yield self.store.create_profile(
self.u_frank.localpart
)
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
self.assertEquals(
"http://my.site/here",
(yield self.store.get_profile_avatar_url(self.u_frank.localpart))
)

View File

@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage.registration import RegistrationStore
from tests.utils import SQLiteMemoryDbPool
class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
self.store = RegistrationStore(hs)
self.user_id = "@my-user:test"
self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
"BcDeFgHiJkLmNoPqRsTuVwXyZa"]
self.pwhash = "{xx1}123456789"
@defer.inlineCallbacks
def test_register(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
self.assertEquals(
# TODO(paul): Surely this field should be 'user_id', not 'name'
# Additionally surely it shouldn't come in a 1-element list
[{"name": self.user_id, "password_hash": self.pwhash}],
(yield self.store.get_user_by_id(self.user_id))
)
self.assertEquals(
self.user_id,
(yield self.store.get_user_by_token(self.tokens[0]))
)
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
self.assertEquals(
self.user_id,
(yield self.store.get_user_by_token(self.tokens[1]))
)

176
tests/storage/test_room.py Normal file
View File

@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.events.room import (
RoomNameEvent, RoomTopicEvent
)
from tests.utils import SQLiteMemoryDbPool
class RoomStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastore()
self.room = hs.parse_roomid("!abcde:test")
self.alias = hs.parse_roomalias("#a-room-name:test")
self.u_creator = hs.parse_userid("@creator:test")
yield self.store.store_room(self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
is_public=True
)
@defer.inlineCallbacks
def test_get_room(self):
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True},
(yield self.store.get_room(self.room.to_string()))
)
@defer.inlineCallbacks
def test_store_room_config(self):
yield self.store.store_room_config(self.room.to_string(),
visibility=False
)
self.assertObjectHasAttributes(
{"is_public": False},
(yield self.store.get_room(self.room.to_string()))
)
@defer.inlineCallbacks
def test_get_rooms(self):
# get_rooms does an INNER JOIN on the room_aliases table :(
rooms = yield self.store.get_rooms(is_public=True)
# Should be empty before we add the alias
self.assertEquals([], rooms)
yield self.store.create_room_alias_association(
room_alias=self.alias,
room_id=self.room.to_string(),
servers=["test"]
)
rooms = yield self.store.get_rooms(is_public=True)
self.assertEquals(1, len(rooms))
self.assertEquals({
"name": None,
"room_id": self.room.to_string(),
"topic": None,
"aliases": [self.alias.to_string()],
}, rooms[0])
class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory();
self.room = hs.parse_roomid("!abcde:test")
yield self.store.store_room(self.room.to_string(),
room_creator_user_id="@creator:text",
is_public=True
)
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
yield self.store.persist_event(
self.event_factory.create_event(
room_id=self.room.to_string(),
**kwargs
)
)
@defer.inlineCallbacks
def test_room_name(self):
name = u"A-Room-Name"
yield self.inject_room_event(
etype=RoomNameEvent.TYPE,
name=name,
content={"name": name},
depth=1,
)
state = yield self.store.get_current_state(
room_id=self.room.to_string()
)
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
{"type": "m.room.name",
"room_id": self.room.to_string(),
"name": name},
state[0]
)
@defer.inlineCallbacks
def test_room_name(self):
topic = u"A place for things"
yield self.inject_room_event(
etype=RoomTopicEvent.TYPE,
topic=topic,
content={"topic": topic},
depth=1,
)
state = yield self.store.get_current_state(
room_id=self.room.to_string()
)
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
{"type": "m.room.topic",
"room_id": self.room.to_string(),
"topic": topic},
state[0]
)
# Not testing the various 'level' methods for now because there's lots
# of them and need coalescing; see JIRA SPEC-11

View File

@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent
from tests.utils import SQLiteMemoryDbPool
class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test")
# User elsewhere on another host
self.u_charlie = hs.parse_userid("@charlie:elsewhere")
self.room = hs.parse_roomid("!abc123:test")
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
# Have to create a join event using the eventfactory
yield self.store.persist_event(
self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=1,
)
)
@defer.inlineCallbacks
def test_one_member(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
Membership.JOIN,
(yield self.store.get_room_member(
user_id=self.u_alice.to_string(),
room_id=self.room.to_string(),
)).membership
)
self.assertEquals(
[self.u_alice.to_string()],
[m.user_id for m in (
yield self.store.get_room_members(self.room.to_string())
)]
)
self.assertEquals(
[self.room.to_string()],
[m.room_id for m in (
yield self.store.get_rooms_for_user_where_membership_is(
self.u_alice.to_string(), [Membership.JOIN]
))
]
)
self.assertFalse(
(yield self.store.user_rooms_intersect(
[self.u_alice.to_string(), self.u_bob.to_string()]
))
)
@defer.inlineCallbacks
def test_two_members(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.assertEquals(
{self.u_alice.to_string(), self.u_bob.to_string()},
{m.user_id for m in (
yield self.store.get_room_members(self.room.to_string())
)}
)
self.assertTrue(
(yield self.store.user_rooms_intersect(
[self.u_alice.to_string(), self.u_bob.to_string()]
))
)
@defer.inlineCallbacks
def test_room_hosts(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
["test"],
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)
# Should still have just one host after second join from it
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.assertEquals(
["test"],
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)
# Should now have two hosts after join from other host
yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
self.assertEquals(
{"test", "elsewhere"},
set((yield
self.store.get_joined_hosts_for_room(self.room.to_string())
))
)
# Should still have both hosts
yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
self.assertEquals(
{"test", "elsewhere"},
set((yield
self.store.get_joined_hosts_for_room(self.room.to_string())
))
)
# Should have only one host after other leaves
yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
self.assertEquals(
["test"],
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)

View File

@ -71,6 +71,17 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level) logging.getLogger().setLevel(level)
return orig() return orig()
def assertObjectHasAttributes(self, attrs, obj):
"""Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEquals."""
for (key, value) in attrs.items():
if not hasattr(obj, key):
raise AssertionError("Expected obj to have a '.%s'" % key)
try:
self.assertEquals(attrs[key], getattr(obj, key))
except AssertionError as e:
raise (type(e))(e.message + " for '.%s'" % key)
def DEBUG(target): def DEBUG(target):
"""A decorator to set the .loglevel attribute to logging.DEBUG. """A decorator to set the .loglevel attribute to logging.DEBUG.

View File

@ -16,12 +16,14 @@
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.api.errors import cs_error, CodeMessageException, StoreError from synapse.api.errors import cs_error, CodeMessageException, StoreError
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.storage import prepare_database
from synapse.api.events.room import ( from synapse.api.events.room import (
RoomMemberEvent, MessageEvent RoomMemberEvent, MessageEvent
) )
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.enterprise.adbapi import ConnectionPool
from collections import namedtuple from collections import namedtuple
from mock import patch, Mock from mock import patch, Mock
@ -120,6 +122,18 @@ class MockClock(object):
self.now += secs self.now += secs
class SQLiteMemoryDbPool(ConnectionPool, object):
def __init__(self):
super(SQLiteMemoryDbPool, self).__init__(
"sqlite3", ":memory:",
cp_min=1,
cp_max=1,
)
def prepare(self):
return self.runWithConnection(prepare_database)
class MemoryDataStore(object): class MemoryDataStore(object):
Room = namedtuple( Room = namedtuple(