Rename PDU fields to match that of events.

This commit is contained in:
Erik Johnston 2014-11-03 13:06:58 +00:00
parent d59aa6af25
commit ad6eacb3e9
6 changed files with 80 additions and 316 deletions

View File

@ -32,7 +32,7 @@ def prune_event(event):
def prune_pdu(pdu): def prune_pdu(pdu):
"""Removes keys that contain unrestricted and non-essential data from a PDU """Removes keys that contain unrestricted and non-essential data from a PDU
""" """
return _prune_event_or_pdu(pdu.pdu_type, pdu) return _prune_event_or_pdu(pdu.type, pdu)
def _prune_event_or_pdu(event_type, event): def _prune_event_or_pdu(event_type, event):
# Remove all extraneous fields. # Remove all extraneous fields.

View File

@ -31,39 +31,16 @@ class PduCodec(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
def encode_event_id(self, local, domain):
return local
def decode_event_id(self, event_id):
e_id = self.hs.parse_eventid(event_id)
return event_id, e_id.domain
def event_from_pdu(self, pdu): def event_from_pdu(self, pdu):
kwargs = {} kwargs = {}
kwargs["event_id"] = self.encode_event_id(pdu.pdu_id, pdu.origin) kwargs["etype"] = pdu.type
kwargs["room_id"] = pdu.context
kwargs["etype"] = pdu.pdu_type
kwargs["prev_events"] = [
(self.encode_event_id(i, o), s)
for i, o, s in pdu.prev_pdus
]
if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
kwargs["prev_state"] = self.encode_event_id(
pdu.prev_state_id, pdu.prev_state_origin
)
kwargs.update({ kwargs.update({
k: v k: v
for k, v in pdu.get_full_dict().items() for k, v in pdu.get_full_dict().items()
if k not in [ if k not in [
"pdu_id", "type",
"context",
"pdu_type",
"prev_pdus",
"prev_state_id",
"prev_state_origin",
] ]
}) })
@ -72,33 +49,12 @@ class PduCodec(object):
def pdu_from_event(self, event): def pdu_from_event(self, event):
d = event.get_full_dict() d = event.get_full_dict()
d["pdu_id"], d["origin"] = self.decode_event_id(
event.event_id
)
d["context"] = event.room_id
d["pdu_type"] = event.type
if hasattr(event, "prev_events"):
def f(e, s):
i, o = self.decode_event_id(e)
return i, o, s
d["prev_pdus"] = [
f(e, s)
for e, s in event.prev_events
]
if hasattr(event, "prev_state"):
d["prev_state_id"], d["prev_state_origin"] = (
self.decode_event_id(event.prev_state)
)
if hasattr(event, "state_key"): if hasattr(event, "state_key"):
d["is_state"] = True d["is_state"] = True
kwargs = copy.deepcopy(event.unrecognized_keys) kwargs = copy.deepcopy(event.unrecognized_keys)
kwargs.update({ kwargs.update({
k: v for k, v in d.items() k: v for k, v in d.items()
if k not in ["event_id", "room_id", "type", "prev_events"]
}) })
if "origin_server_ts" not in kwargs: if "origin_server_ts" not in kwargs:

View File

@ -111,14 +111,6 @@ class ReplicationLayer(object):
"""Informs the replication layer about a new PDU generated within the """Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others. home server that should be transmitted to others.
This will fill out various attributes on the PDU object, e.g. the
`prev_pdus` key.
*Note:* The home server should always call `send_pdu` even if it knows
that it does not need to be replicated to other home servers. This is
in case e.g. someone else joins via a remote home server and then
backfills.
TODO: Figure out when we should actually resolve the deferred. TODO: Figure out when we should actually resolve the deferred.
Args: Args:
@ -131,18 +123,12 @@ class ReplicationLayer(object):
order = self._order order = self._order
self._order += 1 self._order += 1
logger.debug("[%s] Persisting PDU", pdu.pdu_id) logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# Save *before* trying to send
# yield self.store.persist_event(pdu=pdu)
logger.debug("[%s] Persisted PDU", pdu.pdu_id)
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
# TODO, add errback, etc. # TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order) self._transaction_queue.enqueue_pdu(pdu, order)
logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id) logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.event_id)
@log_function @log_function
def send_edu(self, destination, edu_type, content): def send_edu(self, destination, edu_type, content):
@ -215,7 +201,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False): def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home """Requests the PDU with given origin and ID from the remote home
server. server.
@ -224,7 +210,7 @@ class ReplicationLayer(object):
Args: Args:
destination (str): Which home server to query destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu. pdu_origin (str): The home server that originally sent the pdu.
pdu_id (str) event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False` of the current block of PDUs. Defaults to `False`
@ -233,8 +219,9 @@ class ReplicationLayer(object):
Deferred: Results in the requested PDU. Deferred: Results in the requested PDU.
""" """
transaction_data = yield self.transport_layer.get_pdu( transaction_data = yield self.transport_layer.get_event(
destination, pdu_origin, pdu_id) destination, event_id
)
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
@ -249,8 +236,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_state_for_context(self, destination, context, pdu_id=None, def get_state_for_context(self, destination, context, event_id=None):
pdu_origin=None):
"""Requests all of the `current` state PDUs for a given context from """Requests all of the `current` state PDUs for a given context from
a remote home server. a remote home server.
@ -263,7 +249,9 @@ class ReplicationLayer(object):
""" """
transaction_data = yield self.transport_layer.get_context_state( transaction_data = yield self.transport_layer.get_context_state(
destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin, destination,
context,
event_id=event_id,
) )
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
@ -352,10 +340,10 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_context_state_request(self, context, pdu_id, pdu_origin): def on_context_state_request(self, context, event_id):
if pdu_id and pdu_origin: if event_id:
pdus = yield self.handler.get_state_for_pdu( pdus = yield self.handler.get_state_for_pdu(
pdu_id, pdu_origin event_id
) )
else: else:
raise NotImplementedError("Specify an event") raise NotImplementedError("Specify an event")
@ -370,8 +358,8 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_pdu_request(self, pdu_origin, pdu_id): def on_pdu_request(self, event_id):
pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin) pdu = yield self._get_persisted_pdu(event_id)
if pdu: if pdu:
defer.returnValue( defer.returnValue(
@ -443,9 +431,8 @@ class ReplicationLayer(object):
def send_join(self, destination, pdu): def send_join(self, destination, pdu):
_, content = yield self.transport_layer.send_join( _, content = yield self.transport_layer.send_join(
destination, destination,
pdu.context, pdu.room_id,
pdu.pdu_id, pdu.event_id,
pdu.origin,
pdu.get_dict(), pdu.get_dict(),
) )
@ -457,13 +444,13 @@ class ReplicationLayer(object):
defer.returnValue(pdus) defer.returnValue(pdus)
@log_function @log_function
def _get_persisted_pdu(self, pdu_id, pdu_origin): def _get_persisted_pdu(self, event_id):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.
Returns: Returns:
Deferred: Results in a `Pdu`. Deferred: Results in a `Pdu`.
""" """
return self.handler.get_persisted_pdu(pdu_id, pdu_origin) return self.handler.get_persisted_pdu(event_id)
def _transaction_from_pdus(self, pdu_list): def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for """Returns a new Transaction containing the given PDUs suitable for
@ -487,10 +474,10 @@ class ReplicationLayer(object):
@log_function @log_function
def _handle_new_pdu(self, origin, pdu, backfilled=False): def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers # We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) existing = yield self._get_persisted_pdu(pdu.event_id)
if existing and (not existing.outlier or pdu.outlier): if existing and (not existing.outlier or pdu.outlier):
logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin) logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({}) defer.returnValue({})
return return
@ -500,23 +487,22 @@ class ReplicationLayer(object):
if not pdu.outlier: if not pdu.outlier:
# We only backfill backwards to the min depth. # We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context( min_depth = yield self.handler.get_min_depth_for_context(
pdu.context pdu.room_id
) )
if min_depth and pdu.depth > min_depth: if min_depth and pdu.depth > min_depth:
for pdu_id, origin, hashes in pdu.prev_pdus: for event_id, hashes in pdu.prev_events:
exists = yield self._get_persisted_pdu(pdu_id, origin) exists = yield self._get_persisted_pdu(event_id)
if not exists: if not exists:
logger.debug("Requesting pdu %s %s", pdu_id, origin) logger.debug("Requesting pdu %s", event_id)
try: try:
yield self.get_pdu( yield self.get_pdu(
pdu.origin, pdu.origin,
pdu_id=pdu_id, event_id=event_id,
pdu_origin=origin
) )
logger.debug("Processed pdu %s %s", pdu_id, origin) logger.debug("Processed pdu %s", event_id)
except: except:
# TODO(erikj): Do some more intelligent retries. # TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU") logger.exception("Failed to get PDU")
@ -524,7 +510,7 @@ class ReplicationLayer(object):
# We need to get the state at this event, since we have reached # We need to get the state at this event, since we have reached
# a backward extremity edge. # a backward extremity edge.
state = yield self.get_state_for_context( state = yield self.get_state_for_context(
origin, pdu.context, pdu.pdu_id, pdu.origin, origin, pdu.room_id, pdu.event_id,
) )
# Persist the Pdu, but don't mark it as processed yet. # Persist the Pdu, but don't mark it as processed yet.

View File

@ -72,8 +72,7 @@ class TransportLayer(object):
self.received_handler = None self.received_handler = None
@log_function @log_function
def get_context_state(self, destination, context, pdu_id=None, def get_context_state(self, destination, context, event_id=None):
pdu_origin=None):
""" Requests all state for a given context (i.e. room) from the """ Requests all state for a given context (i.e. room) from the
given server. given server.
@ -91,60 +90,59 @@ class TransportLayer(object):
subpath = "/state/%s/" % context subpath = "/state/%s/" % context
args = {} args = {}
if pdu_id and pdu_origin: if event_id:
args["pdu_id"] = pdu_id args["event_id"] = event_id
args["pdu_origin"] = pdu_origin
return self._do_request_for_transaction( return self._do_request_for_transaction(
destination, subpath, args=args destination, subpath, args=args
) )
@log_function @log_function
def get_pdu(self, destination, pdu_origin, pdu_id): def get_event(self, destination, event_id):
""" Requests the pdu with give id and origin from the given server. """ Requests the pdu with give id and origin from the given server.
Args: Args:
destination (str): The host name of the remote home server we want destination (str): The host name of the remote home server we want
to get the state from. to get the state from.
pdu_origin (str): The home server which created the PDU. event_id (str): The id of the event being requested.
pdu_id (str): The id of the PDU being requested.
Returns: Returns:
Deferred: Results in a dict received from the remote homeserver. Deferred: Results in a dict received from the remote homeserver.
""" """
logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s", logger.debug("get_pdu dest=%s, event_id=%s",
destination, pdu_origin, pdu_id) destination, event_id)
subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id) subpath = "/event/%s/" % (event_id, )
return self._do_request_for_transaction(destination, subpath) return self._do_request_for_transaction(destination, subpath)
@log_function @log_function
def backfill(self, dest, context, pdu_tuples, limit): def backfill(self, dest, context, event_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of """ Requests `limit` previous PDUs in a given context before list of
PDUs. PDUs.
Args: Args:
dest (str) dest (str)
context (str) context (str)
pdu_tuples (list) event_tuples (list)
limt (int) limt (int)
Returns: Returns:
Deferred: Results in a dict received from the remote homeserver. Deferred: Results in a dict received from the remote homeserver.
""" """
logger.debug( logger.debug(
"backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s", "backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
dest, context, repr(pdu_tuples), str(limit) dest, context, repr(event_tuples), str(limit)
) )
if not pdu_tuples: if not event_tuples:
# TODO: raise?
return return
subpath = "/backfill/%s/" % context subpath = "/backfill/%s/" % (context,)
args = { args = {
"v": ["%s,%s" % (i, o) for i, o in pdu_tuples], "v": event_tuples,
"limit": limit, "limit": limit,
} }
@ -222,11 +220,10 @@ class TransportLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_join(self, destination, context, pdu_id, origin, content): def send_join(self, destination, context, event_id, content):
path = PREFIX + "/send_join/%s/%s/%s" % ( path = PREFIX + "/send_join/%s/%s" % (
context, context,
origin, event_id,
pdu_id,
) )
code, content = yield self.client.put_json( code, content = yield self.client.put_json(
@ -242,11 +239,10 @@ class TransportLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_invite(self, destination, context, pdu_id, origin, content): def send_invite(self, destination, context, event_id, content):
path = PREFIX + "/invite/%s/%s/%s" % ( path = PREFIX + "/invite/%s/%s" % (
context, context,
origin, event_id,
pdu_id,
) )
code, content = yield self.client.put_json( code, content = yield self.client.put_json(
@ -376,10 +372,10 @@ class TransportLayer(object):
# data_id pair. # data_id pair.
self.server.register_path( self.server.register_path(
"GET", "GET",
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"), re.compile("^" + PREFIX + "/event/([^/]*)/$"),
self._with_authentication( self._with_authentication(
lambda origin, content, query, pdu_origin, pdu_id: lambda origin, content, query, event_id:
handler.on_pdu_request(pdu_origin, pdu_id) handler.on_pdu_request(event_id)
) )
) )
@ -391,8 +387,7 @@ class TransportLayer(object):
lambda origin, content, query, context: lambda origin, content, query, context:
handler.on_context_state_request( handler.on_context_state_request(
context, context,
query.get("pdu_id", [None])[0], query.get("event_id", [None])[0],
query.get("pdu_origin", [None])[0]
) )
) )
) )
@ -442,9 +437,9 @@ class TransportLayer(object):
self.server.register_path( self.server.register_path(
"PUT", "PUT",
re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"), re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
self._with_authentication( self._with_authentication(
lambda origin, content, query, context, pdu_origin, pdu_id: lambda origin, content, query, context, event_id:
self._on_send_join_request( self._on_send_join_request(
origin, content, query, origin, content, query,
) )
@ -453,9 +448,9 @@ class TransportLayer(object):
self.server.register_path( self.server.register_path(
"PUT", "PUT",
re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"), re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
self._with_authentication( self._with_authentication(
lambda origin, content, query, context, pdu_origin, pdu_id: lambda origin, content, query, context, event_id:
self._on_invite_request( self._on_invite_request(
origin, content, query, origin, content, query,
) )
@ -548,7 +543,7 @@ class TransportLayer(object):
limit = int(limits[-1]) limit = int(limits[-1])
versions = [v.split(",", 1) for v in v_list] versions = v_list
return self.request_handler.on_backfill_request( return self.request_handler.on_backfill_request(
context, versions, limit context, versions, limit
@ -579,120 +574,3 @@ class TransportLayer(object):
) )
defer.returnValue((200, content)) defer.returnValue((200, content))
class TransportReceivedHandler(object):
""" Callbacks used when we receive a transaction
"""
def on_incoming_transaction(self, transaction):
""" Called on PUT /send/<transaction_id>, or on response to a request
that we sent (e.g. a backfill request)
Args:
transaction (synapse.transaction.Transaction): The transaction that
was sent to us.
Returns:
twisted.internet.defer.Deferred: A deferred that gets fired when
the transaction has finished being processed.
The result should be a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
class TransportRequestHandler(object):
""" Handlers used when someone want's data from us
"""
def on_pull_request(self, versions):
""" Called on GET /pull/?v=...
This is hit when a remote home server wants to get all data
after a given transaction. Mainly used when a home server comes back
online and wants to get everything it has missed.
Args:
versions (list): A list of transaction_ids that should be used to
determine what PDUs the remote side have not yet seen.
Returns:
Deferred: Resultsin a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_pdu_request(self, pdu_origin, pdu_id):
""" Called on GET /pdu/<pdu_origin>/<pdu_id>/
Someone wants a particular PDU. This PDU may or may not have originated
from us.
Args:
pdu_origin (str)
pdu_id (str)
Returns:
Deferred: Resultsin a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_context_state_request(self, context):
""" Called on GET /state/<context>/
Gets hit when someone wants all the *current* state for a given
contexts.
Args:
context (str): The name of the context that we're interested in.
Returns:
twisted.internet.defer.Deferred: A deferred that gets fired when
the transaction has finished being processed.
The result should be a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_backfill_request(self, context, versions, limit):
""" Called on GET /backfill/<context>/?v=...&limit=...
Gets hit when we want to backfill backwards on a given context from
the given point.
Args:
context (str): The context to backfill
versions (list): A list of 2-tuples representing where to backfill
from, in the form `(pdu_id, origin)`
limit (int): How many pdus to return.
Returns:
Deferred: Results in a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_query_request(self):
""" Called on a GET /query/<query_type> request. """

View File

@ -34,13 +34,13 @@ class Pdu(JsonEncodedObject):
A Pdu can be classified as "state". For a given context, we can efficiently A Pdu can be classified as "state". For a given context, we can efficiently
retrieve all state pdu's that haven't been clobbered. Clobbering is done retrieve all state pdu's that haven't been clobbered. Clobbering is done
via a unique constraint on the tuple (context, pdu_type, state_key). A pdu via a unique constraint on the tuple (context, type, state_key). A pdu
is a state pdu if `is_state` is True. is a state pdu if `is_state` is True.
Example pdu:: Example pdu::
{ {
"pdu_id": "78c", "event_id": "$78c:example.com",
"origin_server_ts": 1404835423000, "origin_server_ts": 1404835423000,
"origin": "bar", "origin": "bar",
"prev_ids": [ "prev_ids": [
@ -53,14 +53,14 @@ class Pdu(JsonEncodedObject):
""" """
valid_keys = [ valid_keys = [
"pdu_id", "event_id",
"context", "room_id",
"origin", "origin",
"origin_server_ts", "origin_server_ts",
"pdu_type", "type",
"destinations", "destinations",
"transaction_id", "transaction_id",
"prev_pdus", "prev_events",
"depth", "depth",
"content", "content",
"outlier", "outlier",
@ -68,8 +68,7 @@ class Pdu(JsonEncodedObject):
"signatures", "signatures",
"is_state", # Below this are keys valid only for State Pdus. "is_state", # Below this are keys valid only for State Pdus.
"state_key", "state_key",
"prev_state_id", "prev_state",
"prev_state_origin",
"required_power_level", "required_power_level",
"user_id", "user_id",
] ]
@ -81,18 +80,18 @@ class Pdu(JsonEncodedObject):
] ]
required_keys = [ required_keys = [
"pdu_id", "event_id",
"context", "room_id",
"origin", "origin",
"origin_server_ts", "origin_server_ts",
"pdu_type", "type",
"content", "content",
] ]
# TODO: We need to make this properly load content rather than # TODO: We need to make this properly load content rather than
# just leaving it as a dict. (OR DO WE?!) # just leaving it as a dict. (OR DO WE?!)
def __init__(self, destinations=[], is_state=False, prev_pdus=[], def __init__(self, destinations=[], is_state=False, prev_events=[],
outlier=False, hashes={}, signatures={}, **kwargs): outlier=False, hashes={}, signatures={}, **kwargs):
if is_state: if is_state:
for required_key in ["state_key"]: for required_key in ["state_key"]:
@ -102,66 +101,13 @@ class Pdu(JsonEncodedObject):
super(Pdu, self).__init__( super(Pdu, self).__init__(
destinations=destinations, destinations=destinations,
is_state=bool(is_state), is_state=bool(is_state),
prev_pdus=prev_pdus, prev_events=prev_events,
outlier=outlier, outlier=outlier,
hashes=hashes, hashes=hashes,
signatures=signatures, signatures=signatures,
**kwargs **kwargs
) )
@classmethod
def from_pdu_tuple(cls, pdu_tuple):
""" Converts a PduTuple to a Pdu
Args:
pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
convert
Returns:
Pdu
"""
if pdu_tuple:
d = copy.copy(pdu_tuple.pdu_entry._asdict())
d["origin_server_ts"] = d.pop("ts")
for k in d.keys():
if d[k] is None:
del d[k]
d["content"] = json.loads(d["content_json"])
del d["content_json"]
args = {f: d[f] for f in cls.valid_keys if f in d}
if "unrecognized_keys" in d and d["unrecognized_keys"]:
args.update(json.loads(d["unrecognized_keys"]))
hashes = {
alg: encode_base64(hsh)
for alg, hsh in pdu_tuple.hashes.items()
}
signatures = {
kid: encode_base64(sig)
for kid, sig in pdu_tuple.signatures.items()
}
prev_pdus = []
for prev_pdu in pdu_tuple.prev_pdu_list:
prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {})
prev_hashes = {
alg: encode_base64(hsh) for alg, hsh in prev_hashes.items()
}
prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes))
return Pdu(
prev_pdus=prev_pdus,
hashes=hashes,
signatures=signatures,
**args
)
else:
return None
def __str__(self): def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))

View File

@ -139,7 +139,7 @@ class FederationHandler(BaseHandler):
# Huh, let's try and get the current state # Huh, let's try and get the current state
try: try:
yield self.replication_layer.get_state_for_context( yield self.replication_layer.get_state_for_context(
event.origin, event.room_id, pdu.pdu_id, pdu.origin, event.origin, event.room_id, event.event_id,
) )
hosts = yield self.store.get_joined_hosts_for_room( hosts = yield self.store.get_joined_hosts_for_room(
@ -368,11 +368,9 @@ class FederationHandler(BaseHandler):
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, pdu_id, pdu_origin): def get_state_for_pdu(self, event_id):
yield run_on_reactor() yield run_on_reactor()
event_id = EventID.create(pdu_id, pdu_origin, self.hs).to_string()
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
[event_id] [event_id]
) )
@ -406,7 +404,7 @@ class FederationHandler(BaseHandler):
events = yield self.store.get_backfill_events( events = yield self.store.get_backfill_events(
context, context,
[self.pdu_codec.encode_event_id(i, o) for i, o in pdu_list], pdu_list,
limit limit
) )
@ -417,14 +415,14 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_persisted_pdu(self, pdu_id, origin): def get_persisted_pdu(self, event_id):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.
Returns: Returns:
Deferred: Results in a `Pdu`. Deferred: Results in a `Pdu`.
""" """
event = yield self.store.get_event( event = yield self.store.get_event(
self.pdu_codec.encode_event_id(pdu_id, origin), event_id,
allow_none=True, allow_none=True,
) )