Supply auth_chain along with current state in '/state/', fetch auth events from a remote server if we are missing some of them
This commit is contained in:
parent
dbe77ec79a
commit
041ac476a5
|
@ -256,31 +256,35 @@ class ReplicationLayer(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_state_for_context(self, destination, context, event_id=None):
|
||||
def get_state_for_context(self, destination, context, event_id):
|
||||
"""Requests all of the `current` state PDUs for a given context from
|
||||
a remote home server.
|
||||
|
||||
Args:
|
||||
destination (str): The remote homeserver to query for the state.
|
||||
context (str): The context we're interested in.
|
||||
event_id (str): The id of the event we want the state at.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a list of PDUs.
|
||||
"""
|
||||
|
||||
transaction_data = yield self.transport_layer.get_context_state(
|
||||
result = yield self.transport_layer.get_context_state(
|
||||
destination,
|
||||
context,
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
transaction = Transaction(**transaction_data)
|
||||
pdus = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in transaction.pdus
|
||||
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||
]
|
||||
|
||||
defer.returnValue(pdus)
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in result.get("auth_chain", [])
|
||||
]
|
||||
|
||||
defer.returnValue((pdus, auth_chain))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -383,10 +387,16 @@ class ReplicationLayer(object):
|
|||
context,
|
||||
event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
defer.returnValue((200, {
|
||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -573,6 +583,8 @@ class ReplicationLayer(object):
|
|||
|
||||
state = None
|
||||
|
||||
auth_chain = []
|
||||
|
||||
# We need to make sure we have all the auth events.
|
||||
# for e_id, _ in pdu.auth_events:
|
||||
# exists = yield self._get_persisted_pdu(
|
||||
|
@ -645,7 +657,7 @@ class ReplicationLayer(object):
|
|||
"_handle_new_pdu getting state for %s",
|
||||
pdu.room_id
|
||||
)
|
||||
state = yield self.get_state_for_context(
|
||||
state, auth_chain = yield self.get_state_for_context(
|
||||
origin, pdu.room_id, pdu.event_id,
|
||||
)
|
||||
|
||||
|
@ -655,6 +667,7 @@ class ReplicationLayer(object):
|
|||
pdu,
|
||||
backfilled=backfilled,
|
||||
state=state,
|
||||
auth_chain=auth_chain,
|
||||
)
|
||||
else:
|
||||
ret = None
|
||||
|
|
|
@ -95,7 +95,8 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, backfilled, state=None):
|
||||
def on_receive_pdu(self, origin, pdu, backfilled, state=None,
|
||||
auth_chain=None):
|
||||
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
||||
do auth checks and put it through the StateHandler.
|
||||
"""
|
||||
|
@ -150,35 +151,35 @@ class FederationHandler(BaseHandler):
|
|||
if not is_in_room and not event.internal_metadata.outlier:
|
||||
logger.debug("Got event for room we're not in.")
|
||||
|
||||
replication_layer = self.replication_layer
|
||||
auth_chain = yield replication_layer.get_event_auth(
|
||||
origin,
|
||||
context=event.room_id,
|
||||
event_id=event.event_id,
|
||||
)
|
||||
replication = self.replication_layer
|
||||
|
||||
if not state:
|
||||
state, auth_chain = yield replication.get_state_for_context(
|
||||
origin, context=event.room_id, event_id=event.event_id,
|
||||
)
|
||||
|
||||
if not auth_chain:
|
||||
auth_chain = yield replication.get_event_auth(
|
||||
origin,
|
||||
context=event.room_id,
|
||||
event_id=event.event_id,
|
||||
)
|
||||
|
||||
for e in auth_chain:
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e, fetch_missing=False)
|
||||
yield self._handle_new_event(e, fetch_auth_from=origin)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle auth event %s",
|
||||
e.event_id,
|
||||
)
|
||||
|
||||
if not state:
|
||||
state = yield replication_layer.get_state_for_context(
|
||||
origin,
|
||||
context=event.room_id,
|
||||
event_id=event.event_id,
|
||||
)
|
||||
# FIXME: Get auth chain for these state events
|
||||
|
||||
current_state = state
|
||||
|
||||
if state:
|
||||
for e in state:
|
||||
logging.info("A :) %r", e)
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e)
|
||||
|
@ -392,7 +393,7 @@ class FederationHandler(BaseHandler):
|
|||
for e in auth_chain:
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e, fetch_missing=False)
|
||||
yield self._handle_new_event(e)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle auth event %s",
|
||||
|
@ -404,8 +405,7 @@ class FederationHandler(BaseHandler):
|
|||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(
|
||||
e,
|
||||
fetch_missing=True
|
||||
e, fetch_auth_from=target_host
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
|
@ -682,7 +682,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_event(self, event, state=None, backfilled=False,
|
||||
current_state=None, fetch_missing=True):
|
||||
current_state=None, fetch_auth_from=None):
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before annotate: %s, sigs: %s",
|
||||
|
@ -703,11 +703,20 @@ class FederationHandler(BaseHandler):
|
|||
known_ids = set(
|
||||
[s.event_id for s in context.auth_events.values()]
|
||||
)
|
||||
|
||||
for e_id, _ in event.auth_events:
|
||||
if e_id not in known_ids:
|
||||
e = yield self.store.get_event(
|
||||
e_id, allow_none=True,
|
||||
)
|
||||
e = yield self.store.get_event(e_id, allow_none=True)
|
||||
|
||||
if not e and fetch_auth_from is not None:
|
||||
# Grab the auth_chain over federation if we are missing
|
||||
# auth events.
|
||||
auth_chain = yield self.replication_layer.get_event_auth(
|
||||
fetch_auth_from, event.event_id, event.room_id
|
||||
)
|
||||
for auth_event in auth_chain:
|
||||
yield self._handle_new_event(auth_event)
|
||||
e = yield self.store.get_event(e_id, allow_none=True)
|
||||
|
||||
if not e:
|
||||
# TODO: Do some conflict res to make sure that we're
|
||||
|
|
|
@ -120,5 +120,5 @@ class Signal(object):
|
|||
results = []
|
||||
for deferred in deferreds:
|
||||
result = yield deferred
|
||||
results.append(results)
|
||||
results.append(result)
|
||||
defer.returnValue(results)
|
||||
|
|
|
@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
"get_received_txn_response",
|
||||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
"get_auth_chain",
|
||||
])
|
||||
self.mock_persistence.get_received_txn_response.return_value = (
|
||||
defer.succeed(None)
|
||||
|
@ -59,6 +60,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
self.mock_persistence.get_destination_retry_timings.return_value = (
|
||||
defer.succeed(DestinationsTable.EntryType("", 0, 0))
|
||||
)
|
||||
self.mock_persistence.get_auth_chain.return_value = []
|
||||
self.mock_config = Mock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
self.clock = MockClock()
|
||||
|
|
Loading…
Reference in New Issue