Supply auth_chain along with current state in '/state/', fetch auth events from a remote server if we are missing some of them
This commit is contained in:
parent
dbe77ec79a
commit
041ac476a5
|
@ -256,31 +256,35 @@ class ReplicationLayer(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_state_for_context(self, destination, context, event_id=None):
|
def get_state_for_context(self, destination, context, event_id):
|
||||||
"""Requests all of the `current` state PDUs for a given context from
|
"""Requests all of the `current` state PDUs for a given context from
|
||||||
a remote home server.
|
a remote home server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str): The remote homeserver to query for the state.
|
destination (str): The remote homeserver to query for the state.
|
||||||
context (str): The context we're interested in.
|
context (str): The context we're interested in.
|
||||||
|
event_id (str): The id of the event we want the state at.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Results in a list of PDUs.
|
Deferred: Results in a list of PDUs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
transaction_data = yield self.transport_layer.get_context_state(
|
result = yield self.transport_layer.get_context_state(
|
||||||
destination,
|
destination,
|
||||||
context,
|
context,
|
||||||
event_id=event_id,
|
event_id=event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
transaction = Transaction(**transaction_data)
|
|
||||||
pdus = [
|
pdus = [
|
||||||
self.event_from_pdu_json(p, outlier=True)
|
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||||
for p in transaction.pdus
|
|
||||||
]
|
]
|
||||||
|
|
||||||
defer.returnValue(pdus)
|
auth_chain = [
|
||||||
|
self.event_from_pdu_json(p, outlier=True)
|
||||||
|
for p in result.get("auth_chain", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
defer.returnValue((pdus, auth_chain))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -383,10 +387,16 @@ class ReplicationLayer(object):
|
||||||
context,
|
context,
|
||||||
event_id,
|
event_id,
|
||||||
)
|
)
|
||||||
|
auth_chain = yield self.store.get_auth_chain(
|
||||||
|
[pdu.event_id for pdu in pdus]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Specify an event")
|
raise NotImplementedError("Specify an event")
|
||||||
|
|
||||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
defer.returnValue((200, {
|
||||||
|
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||||
|
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||||
|
}))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -573,6 +583,8 @@ class ReplicationLayer(object):
|
||||||
|
|
||||||
state = None
|
state = None
|
||||||
|
|
||||||
|
auth_chain = []
|
||||||
|
|
||||||
# We need to make sure we have all the auth events.
|
# We need to make sure we have all the auth events.
|
||||||
# for e_id, _ in pdu.auth_events:
|
# for e_id, _ in pdu.auth_events:
|
||||||
# exists = yield self._get_persisted_pdu(
|
# exists = yield self._get_persisted_pdu(
|
||||||
|
@ -645,7 +657,7 @@ class ReplicationLayer(object):
|
||||||
"_handle_new_pdu getting state for %s",
|
"_handle_new_pdu getting state for %s",
|
||||||
pdu.room_id
|
pdu.room_id
|
||||||
)
|
)
|
||||||
state = yield self.get_state_for_context(
|
state, auth_chain = yield self.get_state_for_context(
|
||||||
origin, pdu.room_id, pdu.event_id,
|
origin, pdu.room_id, pdu.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -655,6 +667,7 @@ class ReplicationLayer(object):
|
||||||
pdu,
|
pdu,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
state=state,
|
state=state,
|
||||||
|
auth_chain=auth_chain,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ret = None
|
ret = None
|
||||||
|
|
|
@ -95,7 +95,8 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_receive_pdu(self, origin, pdu, backfilled, state=None):
|
def on_receive_pdu(self, origin, pdu, backfilled, state=None,
|
||||||
|
auth_chain=None):
|
||||||
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
||||||
do auth checks and put it through the StateHandler.
|
do auth checks and put it through the StateHandler.
|
||||||
"""
|
"""
|
||||||
|
@ -150,35 +151,35 @@ class FederationHandler(BaseHandler):
|
||||||
if not is_in_room and not event.internal_metadata.outlier:
|
if not is_in_room and not event.internal_metadata.outlier:
|
||||||
logger.debug("Got event for room we're not in.")
|
logger.debug("Got event for room we're not in.")
|
||||||
|
|
||||||
replication_layer = self.replication_layer
|
replication = self.replication_layer
|
||||||
auth_chain = yield replication_layer.get_event_auth(
|
|
||||||
origin,
|
if not state:
|
||||||
context=event.room_id,
|
state, auth_chain = yield replication.get_state_for_context(
|
||||||
event_id=event.event_id,
|
origin, context=event.room_id, event_id=event.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not auth_chain:
|
||||||
|
auth_chain = yield replication.get_event_auth(
|
||||||
|
origin,
|
||||||
|
context=event.room_id,
|
||||||
|
event_id=event.event_id,
|
||||||
|
)
|
||||||
|
|
||||||
for e in auth_chain:
|
for e in auth_chain:
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_event(e, fetch_missing=False)
|
yield self._handle_new_event(e, fetch_auth_from=origin)
|
||||||
except:
|
except:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to handle auth event %s",
|
"Failed to handle auth event %s",
|
||||||
e.event_id,
|
e.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not state:
|
|
||||||
state = yield replication_layer.get_state_for_context(
|
|
||||||
origin,
|
|
||||||
context=event.room_id,
|
|
||||||
event_id=event.event_id,
|
|
||||||
)
|
|
||||||
# FIXME: Get auth chain for these state events
|
|
||||||
|
|
||||||
current_state = state
|
current_state = state
|
||||||
|
|
||||||
if state:
|
if state:
|
||||||
for e in state:
|
for e in state:
|
||||||
|
logging.info("A :) %r", e)
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_event(e)
|
yield self._handle_new_event(e)
|
||||||
|
@ -392,7 +393,7 @@ class FederationHandler(BaseHandler):
|
||||||
for e in auth_chain:
|
for e in auth_chain:
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_event(e, fetch_missing=False)
|
yield self._handle_new_event(e)
|
||||||
except:
|
except:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to handle auth event %s",
|
"Failed to handle auth event %s",
|
||||||
|
@ -404,8 +405,7 @@ class FederationHandler(BaseHandler):
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
try:
|
||||||
yield self._handle_new_event(
|
yield self._handle_new_event(
|
||||||
e,
|
e, fetch_auth_from=target_host
|
||||||
fetch_missing=True
|
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
|
@ -682,7 +682,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _handle_new_event(self, event, state=None, backfilled=False,
|
def _handle_new_event(self, event, state=None, backfilled=False,
|
||||||
current_state=None, fetch_missing=True):
|
current_state=None, fetch_auth_from=None):
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_handle_new_event: Before annotate: %s, sigs: %s",
|
"_handle_new_event: Before annotate: %s, sigs: %s",
|
||||||
|
@ -703,11 +703,20 @@ class FederationHandler(BaseHandler):
|
||||||
known_ids = set(
|
known_ids = set(
|
||||||
[s.event_id for s in context.auth_events.values()]
|
[s.event_id for s in context.auth_events.values()]
|
||||||
)
|
)
|
||||||
|
|
||||||
for e_id, _ in event.auth_events:
|
for e_id, _ in event.auth_events:
|
||||||
if e_id not in known_ids:
|
if e_id not in known_ids:
|
||||||
e = yield self.store.get_event(
|
e = yield self.store.get_event(e_id, allow_none=True)
|
||||||
e_id, allow_none=True,
|
|
||||||
)
|
if not e and fetch_auth_from is not None:
|
||||||
|
# Grab the auth_chain over federation if we are missing
|
||||||
|
# auth events.
|
||||||
|
auth_chain = yield self.replication_layer.get_event_auth(
|
||||||
|
fetch_auth_from, event.event_id, event.room_id
|
||||||
|
)
|
||||||
|
for auth_event in auth_chain:
|
||||||
|
yield self._handle_new_event(auth_event)
|
||||||
|
e = yield self.store.get_event(e_id, allow_none=True)
|
||||||
|
|
||||||
if not e:
|
if not e:
|
||||||
# TODO: Do some conflict res to make sure that we're
|
# TODO: Do some conflict res to make sure that we're
|
||||||
|
|
|
@ -120,5 +120,5 @@ class Signal(object):
|
||||||
results = []
|
results = []
|
||||||
for deferred in deferreds:
|
for deferred in deferreds:
|
||||||
result = yield deferred
|
result = yield deferred
|
||||||
results.append(results)
|
results.append(result)
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
|
@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
|
||||||
"get_received_txn_response",
|
"get_received_txn_response",
|
||||||
"set_received_txn_response",
|
"set_received_txn_response",
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
|
"get_auth_chain",
|
||||||
])
|
])
|
||||||
self.mock_persistence.get_received_txn_response.return_value = (
|
self.mock_persistence.get_received_txn_response.return_value = (
|
||||||
defer.succeed(None)
|
defer.succeed(None)
|
||||||
|
@ -59,6 +60,7 @@ class FederationTestCase(unittest.TestCase):
|
||||||
self.mock_persistence.get_destination_retry_timings.return_value = (
|
self.mock_persistence.get_destination_retry_timings.return_value = (
|
||||||
defer.succeed(DestinationsTable.EntryType("", 0, 0))
|
defer.succeed(DestinationsTable.EntryType("", 0, 0))
|
||||||
)
|
)
|
||||||
|
self.mock_persistence.get_auth_chain.return_value = []
|
||||||
self.mock_config = Mock()
|
self.mock_config = Mock()
|
||||||
self.mock_config.signing_key = [MockKey()]
|
self.mock_config.signing_key = [MockKey()]
|
||||||
self.clock = MockClock()
|
self.clock = MockClock()
|
||||||
|
|
Loading…
Reference in New Issue