Use state groups to get current state. Make join dance actually work.
This commit is contained in:
parent
f71627567b
commit
5ffe5ab43f
|
@ -22,6 +22,7 @@ from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
|
|||
from synapse.api.events.room import (
|
||||
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
|
||||
RoomJoinRulesEvent, RoomOpsPowerLevelsEvent, InviteJoinEvent,
|
||||
RoomCreateEvent,
|
||||
)
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
|
@ -59,6 +60,10 @@ class Auth(object):
|
|||
|
||||
is_state = hasattr(event, "state_key")
|
||||
|
||||
if event.type == RoomCreateEvent.TYPE:
|
||||
# FIXME
|
||||
defer.returnValue(True)
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
yield self._can_replace_state(event)
|
||||
allowed = yield self.is_membership_change_allowed(event)
|
||||
|
|
|
@ -403,11 +403,18 @@ class ReplicationLayer(object):
|
|||
defer.returnValue(
|
||||
(404, "No handler for Query type '%s'" % (query_type, ))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_make_join_request(self, context, user_id):
|
||||
pdu = yield self.handler.on_make_join_request(context, user_id)
|
||||
defer.returnValue(pdu.get_dict())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_invite_request(self, origin, content):
|
||||
pdu = Pdu(**content)
|
||||
ret_pdu = yield self.handler.on_send_join_request(origin, pdu)
|
||||
defer.returnValue((200, ret_pdu.get_dict()))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_send_join_request(self, origin, content):
|
||||
pdu = Pdu(**content)
|
||||
|
@ -426,8 +433,9 @@ class ReplicationLayer(object):
|
|||
|
||||
defer.returnValue(Pdu(**pdu_dict))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_join(self, destination, pdu):
|
||||
return self.transport_layer.send_join(
|
||||
_, content = yield self.transport_layer.send_join(
|
||||
destination,
|
||||
pdu.context,
|
||||
pdu.pdu_id,
|
||||
|
@ -435,6 +443,13 @@ class ReplicationLayer(object):
|
|||
pdu.get_dict(),
|
||||
)
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
pdus = [Pdu(outlier=True, **p) for p in content.get("pdus", [])]
|
||||
for pdu in pdus:
|
||||
yield self._handle_new_pdu(destination, pdu)
|
||||
|
||||
defer.returnValue(pdus)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _get_persisted_pdu(self, pdu_id, pdu_origin):
|
||||
|
|
|
@ -229,13 +229,36 @@ class TransportLayer(object):
|
|||
pdu_id,
|
||||
)
|
||||
|
||||
response = yield self.client.put_json(
|
||||
code, content = yield self.client.put_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
if not 200 <= code < 300:
|
||||
raise RuntimeError("Got %d from send_join", code)
|
||||
|
||||
defer.returnValue(json.loads(content))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_invite(self, destination, context, pdu_id, origin, content):
|
||||
path = PREFIX + "/invite/%s/%s/%s" % (
|
||||
context,
|
||||
origin,
|
||||
pdu_id,
|
||||
)
|
||||
|
||||
code, content = yield self.client.put_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
)
|
||||
|
||||
if not 200 <= code < 300:
|
||||
raise RuntimeError("Got %d from send_invite", code)
|
||||
|
||||
defer.returnValue(json.loads(content))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _authenticate_request(self, request):
|
||||
|
@ -297,9 +320,13 @@ class TransportLayer(object):
|
|||
@defer.inlineCallbacks
|
||||
def new_handler(request, *args, **kwargs):
|
||||
(origin, content) = yield self._authenticate_request(request)
|
||||
response = yield handler(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
try:
|
||||
response = yield handler(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
except:
|
||||
logger.exception("Callback failed")
|
||||
raise
|
||||
defer.returnValue(response)
|
||||
return new_handler
|
||||
|
||||
|
@ -419,6 +446,17 @@ class TransportLayer(object):
|
|||
)
|
||||
)
|
||||
|
||||
self.server.register_path(
|
||||
"PUT",
|
||||
re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)/([^/]*)$"),
|
||||
self._with_authentication(
|
||||
lambda origin, content, query, context, pdu_origin, pdu_id:
|
||||
self._on_invite_request(
|
||||
origin, content, query,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _on_send_request(self, origin, content, query, transaction_id):
|
||||
|
@ -524,6 +562,15 @@ class TransportLayer(object):
|
|||
|
||||
defer.returnValue((200, content))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _on_invite_request(self, origin, content, query):
|
||||
content = yield self.request_handler.on_invite_request(
|
||||
origin, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, content))
|
||||
|
||||
|
||||
class TransportReceivedHandler(object):
|
||||
""" Callbacks used when we receive a transaction
|
||||
|
|
|
@ -62,6 +62,9 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self.pdu_codec = PduCodec(hs)
|
||||
|
||||
# When joining a room we need to queue any events for that room up
|
||||
self.room_queues = {}
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_event(self, event, snapshot):
|
||||
|
@ -95,22 +98,25 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
logger.debug("Got event: %s", event.event_id)
|
||||
|
||||
if event.room_id in self.room_queues:
|
||||
self.room_queues[event.room_id].append(pdu)
|
||||
return
|
||||
|
||||
if state:
|
||||
state = [self.pdu_codec.event_from_pdu(p) for p in state]
|
||||
state = {(e.type, e.state_key): e for e in state}
|
||||
yield self.state_handler.annotate_state_groups(event, state=state)
|
||||
|
||||
is_new_state = yield self.state_handler.annotate_state_groups(
|
||||
event,
|
||||
state=state
|
||||
)
|
||||
|
||||
logger.debug("Event: %s", event)
|
||||
|
||||
if not backfilled:
|
||||
yield self.auth.check(event, None, raises=True)
|
||||
|
||||
if event.is_state and not backfilled:
|
||||
is_new_state = yield self.state_handler.handle_new_state(
|
||||
pdu
|
||||
)
|
||||
else:
|
||||
is_new_state = False
|
||||
is_new_state = is_new_state and not backfilled
|
||||
|
||||
# TODO: Implement something in federation that allows us to
|
||||
# respond to PDU.
|
||||
|
@ -211,6 +217,8 @@ class FederationHandler(BaseHandler):
|
|||
assert(event.state_key == joinee)
|
||||
assert(event.room_id == room_id)
|
||||
|
||||
self.room_queues[room_id] = []
|
||||
|
||||
event.event_id = self.event_factory.create_event_id()
|
||||
event.content = content
|
||||
|
||||
|
@ -219,15 +227,14 @@ class FederationHandler(BaseHandler):
|
|||
self.pdu_codec.pdu_from_event(event)
|
||||
)
|
||||
|
||||
# TODO (erikj): Time out here.
|
||||
d = defer.Deferred()
|
||||
self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
|
||||
reactor.callLater(10, d.cancel)
|
||||
state = [self.pdu_codec.event_from_pdu(p) for p in state]
|
||||
|
||||
try:
|
||||
yield d
|
||||
except defer.CancelledError:
|
||||
raise SynapseError(500, "Unable to join remote room")
|
||||
logger.debug("do_invite_join state: %s", state)
|
||||
|
||||
is_new_state = yield self.state_handler.annotate_state_groups(
|
||||
event,
|
||||
state=state
|
||||
)
|
||||
|
||||
try:
|
||||
yield self.store.store_room(
|
||||
|
@ -239,6 +246,32 @@ class FederationHandler(BaseHandler):
|
|||
# FIXME
|
||||
pass
|
||||
|
||||
for e in state:
|
||||
# FIXME: Auth these.
|
||||
is_new_state = yield self.state_handler.annotate_state_groups(
|
||||
e,
|
||||
state=state
|
||||
)
|
||||
|
||||
yield self.store.persist_event(
|
||||
e,
|
||||
backfilled=False,
|
||||
is_new_state=False
|
||||
)
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
backfilled=False,
|
||||
is_new_state=is_new_state
|
||||
)
|
||||
|
||||
room_queue = self.room_queues[room_id]
|
||||
del self.room_queues[room_id]
|
||||
|
||||
for p in room_queue:
|
||||
p.outlier = True
|
||||
yield self.on_receive_pdu(p, backfilled=False)
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -264,13 +297,9 @@ class FederationHandler(BaseHandler):
|
|||
def on_send_join_request(self, origin, pdu):
|
||||
event = self.pdu_codec.event_from_pdu(pdu)
|
||||
|
||||
yield self.state_handler.annotate_state_groups(event)
|
||||
is_new_state= yield self.state_handler.annotate_state_groups(event)
|
||||
yield self.auth.check(event, None, raises=True)
|
||||
|
||||
is_new_state = yield self.state_handler.handle_new_state(
|
||||
pdu
|
||||
)
|
||||
|
||||
# FIXME (erikj): All this is duplicated above :(
|
||||
|
||||
yield self.store.persist_event(
|
||||
|
@ -303,7 +332,10 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
yield self.replication_layer.send_pdu(new_pdu)
|
||||
|
||||
defer.returnValue(event.state_events.values())
|
||||
defer.returnValue([
|
||||
self.pdu_codec.pdu_from_event(e)
|
||||
for e in event.state_events.values()
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_pdu(self, pdu_id, pdu_origin):
|
||||
|
|
|
@ -199,7 +199,7 @@ class MessageHandler(BaseHandler):
|
|||
raise RoomError(
|
||||
403, "Member does not meet private room rules.")
|
||||
|
||||
data = yield self.store.get_current_state(
|
||||
data = yield self.state_handler.get_current_state(
|
||||
room_id, event_type, state_key
|
||||
)
|
||||
defer.returnValue(data)
|
||||
|
@ -238,7 +238,7 @@ class MessageHandler(BaseHandler):
|
|||
yield self.auth.check_joined_room(room_id, user_id)
|
||||
|
||||
# TODO: This is duplicating logic from snapshot_all_rooms
|
||||
current_state = yield self.store.get_current_state(room_id)
|
||||
current_state = yield self.state_handler.get_current_state(room_id)
|
||||
defer.returnValue([self.hs.serialize_event(c) for c in current_state])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -315,7 +315,7 @@ class MessageHandler(BaseHandler):
|
|||
"end": end_token.to_string(),
|
||||
}
|
||||
|
||||
current_state = yield self.store.get_current_state(
|
||||
current_state = yield self.state_handler.get_current_state(
|
||||
event.room_id
|
||||
)
|
||||
d["state"] = [self.hs.serialize_event(c) for c in current_state]
|
||||
|
|
|
@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
|
|||
from synapse.rest.transactions import HttpTransactionStore
|
||||
import re
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def client_path_pattern(path_regex):
|
||||
"""Creates a regex compiled client path with the correct client path
|
||||
|
|
|
@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.rest.base import RestServlet, client_path_pattern
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class EventStreamRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/events$")
|
||||
|
@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
try:
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
|
||||
timeout=timeout)
|
||||
chunk = yield handler.get_stream(
|
||||
auth_user.to_string(), pagin_config, timeout=timeout
|
||||
)
|
||||
except:
|
||||
logger.exception("Event stream failed")
|
||||
raise
|
||||
|
||||
defer.returnValue((200, chunk))
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.federation.pdu_codec import encode_event_id, decode_event_id
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.federation.pdu_codec import encode_event_id
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
@ -130,43 +131,23 @@ class StateHandler(object):
|
|||
defer.returnValue(is_new)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def annotate_state_groups(self, event, state=None):
|
||||
if state:
|
||||
event.state_group = None
|
||||
event.old_state_events = None
|
||||
event.state_events = state
|
||||
event.state_events = {(s.type, s.state_key): s for s in state}
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
state_groups = yield self.store.get_state_groups(
|
||||
event.prev_events
|
||||
)
|
||||
if hasattr(event, "outlier") and event.outlier:
|
||||
event.state_group = None
|
||||
event.old_state_events = None
|
||||
event.state_events = None
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
state = {}
|
||||
state_sets = {}
|
||||
for group in state_groups:
|
||||
for s in group.state:
|
||||
state.setdefault((s.type, s.state_key), []).append(s)
|
||||
|
||||
state_sets.setdefault(
|
||||
(s.type, s.state_key),
|
||||
set()
|
||||
).add(s.event_id)
|
||||
|
||||
unconflicted_state = {
|
||||
k: state[k].pop() for k, v in state_sets.items()
|
||||
if len(v) == 1
|
||||
}
|
||||
|
||||
conflicted_state = {
|
||||
k: state[k]
|
||||
for k, v in state_sets.items()
|
||||
if len(v) > 1
|
||||
}
|
||||
|
||||
new_state = {}
|
||||
new_state.update(unconflicted_state)
|
||||
for key, events in conflicted_state.items():
|
||||
new_state[key] = yield self.resolve(events)
|
||||
new_state = yield self.resolve_state_groups(event.prev_events)
|
||||
|
||||
event.old_state_events = new_state
|
||||
|
||||
|
@ -176,8 +157,63 @@ class StateHandler(object):
|
|||
event.state_group = None
|
||||
event.state_events = new_state
|
||||
|
||||
defer.returnValue(hasattr(event, "state_key"))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve(self, events):
|
||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
||||
# FIXME: HACK!
|
||||
pdus = yield self.store.get_latest_pdus_in_context(room_id)
|
||||
|
||||
event_ids = [encode_event_id(p.pdu_id, p.origin) for p in pdus]
|
||||
|
||||
res = self.resolve_state_groups(event_ids)
|
||||
|
||||
if event_type:
|
||||
defer.returnValue(res.get((event_type, state_key)))
|
||||
return
|
||||
|
||||
defer.returnValue(res.values())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def resolve_state_groups(self, event_ids):
|
||||
state_groups = yield self.store.get_state_groups(
|
||||
event_ids
|
||||
)
|
||||
|
||||
state = {}
|
||||
for group in state_groups:
|
||||
for s in group.state:
|
||||
state.setdefault(
|
||||
(s.type, s.state_key),
|
||||
{}
|
||||
)[s.event_id] = s
|
||||
|
||||
unconflicted_state = {
|
||||
k: v.values()[0] for k, v in state.items()
|
||||
if len(v.values()) == 1
|
||||
}
|
||||
|
||||
conflicted_state = {
|
||||
k: v.values()
|
||||
for k, v in state.items()
|
||||
if len(v.values()) > 1
|
||||
}
|
||||
|
||||
try:
|
||||
new_state = {}
|
||||
new_state.update(unconflicted_state)
|
||||
for key, events in conflicted_state.items():
|
||||
new_state[key] = yield self._resolve_state_events(events)
|
||||
except:
|
||||
logger.exception("Failed to resolve state")
|
||||
raise
|
||||
|
||||
defer.returnValue(new_state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _resolve_state_events(self, events):
|
||||
curr_events = events
|
||||
|
||||
new_powers_deferreds = []
|
||||
|
|
|
@ -277,6 +277,12 @@ class PduStore(SQLBaseStore):
|
|||
(context, depth)
|
||||
)
|
||||
|
||||
def get_latest_pdus_in_context(self, context):
|
||||
return self.runInteraction(
|
||||
self._get_latest_pdus_in_context,
|
||||
context
|
||||
)
|
||||
|
||||
def _get_latest_pdus_in_context(self, txn, context):
|
||||
"""Get's a list of the most current pdus for a given context. This is
|
||||
used when we are sending a Pdu and need to fill out the `prev_pdus`
|
||||
|
|
|
@ -63,6 +63,9 @@ class StateStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
def _store_state_groups_txn(self, txn, event):
|
||||
if not event.state_events:
|
||||
return
|
||||
|
||||
state_group = event.state_group
|
||||
if not state_group:
|
||||
state_group = self._simple_insert_txn(
|
||||
|
|
Loading…
Reference in New Issue