Remove get_state_for_room function from federation handler

This commit is contained in:
Mark Haines 2014-08-28 15:32:30 +01:00
parent bddc1d9fff
commit 7b079a26a5
3 changed files with 29 additions and 17 deletions

View File

@ -84,12 +84,6 @@ class FederationHandler(BaseHandler):
yield self.replication_layer.send_pdu(pdu) yield self.replication_layer.send_pdu(pdu)
@log_function
def get_state_for_room(self, destination, room_id):
return self.replication_layer.get_state_for_context(
destination, room_id
)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, pdu, backfilled): def on_receive_pdu(self, pdu, backfilled):
@ -139,7 +133,7 @@ class FederationHandler(BaseHandler):
yield self.hs.get_handlers().room_member_handler.change_membership( yield self.hs.get_handlers().room_member_handler.change_membership(
new_event, new_event,
True do_auth=True
) )
else: else:
@ -151,8 +145,8 @@ class FederationHandler(BaseHandler):
if not room: if not room:
# Huh, let's try and get the current state # Huh, let's try and get the current state
try: try:
yield self.get_state_for_room( yield self.replication_layer.get_state_for_context(
event.origin, event.room_id origin, event.room_id
) )
hosts = yield self.store.get_joined_hosts_for_room( hosts = yield self.store.get_joined_hosts_for_room(
@ -161,9 +155,9 @@ class FederationHandler(BaseHandler):
if self.hs.hostname in hosts: if self.hs.hostname in hosts:
try: try:
yield self.store.store_room( yield self.store.store_room(
event.room_id, room_id=event.room_id,
"", room_creator_user_id="",
is_public=False is_public=False,
) )
except: except:
pass pass
@ -209,7 +203,9 @@ class FederationHandler(BaseHandler):
# First get current state to see if we are already joined. # First get current state to see if we are already joined.
try: try:
yield self.get_state_for_room(target_host, room_id) yield self.replication_layer.get_state_for_context(
target_host, room_id
)
hosts = yield self.store.get_joined_hosts_for_room(room_id) hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts: if self.hs.hostname in hosts:
@ -239,8 +235,8 @@ class FederationHandler(BaseHandler):
try: try:
yield self.store.store_room( yield self.store.store_room(
room_id, room_id=room_id,
"", room_creator_user_id="",
is_public=False is_public=False
) )
except: except:

View File

@ -28,6 +28,8 @@ from mock import NonCallableMock, ANY
import logging import logging
from ..utils import get_mock_call_args
logging.getLogger().addHandler(logging.NullHandler()) logging.getLogger().addHandler(logging.NullHandler())
@ -99,9 +101,13 @@ class FederationTestCase(unittest.TestCase):
mem_handler = self.handlers.room_member_handler mem_handler = self.handlers.room_member_handler
self.assertEquals(1, mem_handler.change_membership.call_count) self.assertEquals(1, mem_handler.change_membership.call_count)
self.assertEquals(True, mem_handler.change_membership.call_args[0][1]) call_args = get_mock_call_args(
lambda event, do_auth: None,
mem_handler.change_membership
)
self.assertEquals(True, call_args["do_auth"])
new_event = mem_handler.change_membership.call_args[0][0] new_event = call_args["event"]
self.assertEquals(RoomMemberEvent.TYPE, new_event.type) self.assertEquals(RoomMemberEvent.TYPE, new_event.type)
self.assertEquals(room_id, new_event.room_id) self.assertEquals(room_id, new_event.room_id)
self.assertEquals(user_id, new_event.state_key) self.assertEquals(user_id, new_event.state_key)

View File

@ -28,6 +28,16 @@ from mock import patch, Mock
import json import json
import urlparse import urlparse
from inspect import getcallargs
def get_mock_call_args(pattern_func, mock_func):
""" Return the arguments the mock function was called with interpreted
by the pattern functions argument list.
"""
invoked_args, invoked_kargs = mock_func.call_args
return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
# This is a mock /resource/ not an entire server # This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer): class MockHttpResource(HttpServer):