Merge pull request #6279 from matrix-org/erikj/federation_server_async_await
Port federation_server to async/await
This commit is contained in:
commit
b4465564cc
|
@ -0,0 +1 @@
|
||||||
|
Port `federation_server.py` to async/await.
|
|
@ -21,7 +21,6 @@ from six import iteritems
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.internet.abstract import isIPAddress
|
from twisted.internet.abstract import isIPAddress
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
|
|
||||||
|
@ -86,14 +85,12 @@ class FederationServer(FederationBase):
|
||||||
# come in waves.
|
# come in waves.
|
||||||
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
|
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_backfill_request(self, origin, room_id, versions, limit):
|
||||||
@log_function
|
with (await self._server_linearizer.queue((origin, room_id))):
|
||||||
def on_backfill_request(self, origin, room_id, versions, limit):
|
|
||||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
pdus = yield self.handler.on_backfill_request(
|
pdus = await self.handler.on_backfill_request(
|
||||||
origin, room_id, versions, limit
|
origin, room_id, versions, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -101,9 +98,7 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
return 200, res
|
return 200, res
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_incoming_transaction(self, origin, transaction_data):
|
||||||
@log_function
|
|
||||||
def on_incoming_transaction(self, origin, transaction_data):
|
|
||||||
# keep this as early as possible to make the calculated origin ts as
|
# keep this as early as possible to make the calculated origin ts as
|
||||||
# accurate as possible.
|
# accurate as possible.
|
||||||
request_time = self._clock.time_msec()
|
request_time = self._clock.time_msec()
|
||||||
|
@ -118,18 +113,17 @@ class FederationServer(FederationBase):
|
||||||
# use a linearizer to ensure that we don't process the same transaction
|
# use a linearizer to ensure that we don't process the same transaction
|
||||||
# multiple times in parallel.
|
# multiple times in parallel.
|
||||||
with (
|
with (
|
||||||
yield self._transaction_linearizer.queue(
|
await self._transaction_linearizer.queue(
|
||||||
(origin, transaction.transaction_id)
|
(origin, transaction.transaction_id)
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
result = yield self._handle_incoming_transaction(
|
result = await self._handle_incoming_transaction(
|
||||||
origin, transaction, request_time
|
origin, transaction, request_time
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _handle_incoming_transaction(self, origin, transaction, request_time):
|
||||||
def _handle_incoming_transaction(self, origin, transaction, request_time):
|
|
||||||
""" Process an incoming transaction and return the HTTP response
|
""" Process an incoming transaction and return the HTTP response
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -140,7 +134,7 @@ class FederationServer(FederationBase):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[(int, object)]: http response code and body
|
Deferred[(int, object)]: http response code and body
|
||||||
"""
|
"""
|
||||||
response = yield self.transaction_actions.have_responded(origin, transaction)
|
response = await self.transaction_actions.have_responded(origin, transaction)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -159,7 +153,7 @@ class FederationServer(FederationBase):
|
||||||
logger.info("Transaction PDU or EDU count too large. Returning 400")
|
logger.info("Transaction PDU or EDU count too large. Returning 400")
|
||||||
|
|
||||||
response = {}
|
response = {}
|
||||||
yield self.transaction_actions.set_response(
|
await self.transaction_actions.set_response(
|
||||||
origin, transaction, 400, response
|
origin, transaction, 400, response
|
||||||
)
|
)
|
||||||
return 400, response
|
return 400, response
|
||||||
|
@ -195,7 +189,7 @@ class FederationServer(FederationBase):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
room_version = yield self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
except NotFoundError:
|
except NotFoundError:
|
||||||
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
|
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
|
||||||
continue
|
continue
|
||||||
|
@ -221,11 +215,10 @@ class FederationServer(FederationBase):
|
||||||
# require callouts to other servers to fetch missing events), but
|
# require callouts to other servers to fetch missing events), but
|
||||||
# impose a limit to avoid going too crazy with ram/cpu.
|
# impose a limit to avoid going too crazy with ram/cpu.
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def process_pdus_for_room(room_id):
|
||||||
def process_pdus_for_room(room_id):
|
|
||||||
logger.debug("Processing PDUs for %s", room_id)
|
logger.debug("Processing PDUs for %s", room_id)
|
||||||
try:
|
try:
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warn("Ignoring PDUs for room %s from banned server", room_id)
|
logger.warn("Ignoring PDUs for room %s from banned server", room_id)
|
||||||
for pdu in pdus_by_room[room_id]:
|
for pdu in pdus_by_room[room_id]:
|
||||||
|
@ -237,7 +230,7 @@ class FederationServer(FederationBase):
|
||||||
event_id = pdu.event_id
|
event_id = pdu.event_id
|
||||||
with nested_logging_context(event_id):
|
with nested_logging_context(event_id):
|
||||||
try:
|
try:
|
||||||
yield self._handle_received_pdu(origin, pdu)
|
await self._handle_received_pdu(origin, pdu)
|
||||||
pdu_results[event_id] = {}
|
pdu_results[event_id] = {}
|
||||||
except FederationError as e:
|
except FederationError as e:
|
||||||
logger.warn("Error handling PDU %s: %s", event_id, e)
|
logger.warn("Error handling PDU %s: %s", event_id, e)
|
||||||
|
@ -251,36 +244,33 @@ class FederationServer(FederationBase):
|
||||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||||
)
|
)
|
||||||
|
|
||||||
yield concurrently_execute(
|
await concurrently_execute(
|
||||||
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
|
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(transaction, "edus"):
|
if hasattr(transaction, "edus"):
|
||||||
for edu in (Edu(**x) for x in transaction.edus):
|
for edu in (Edu(**x) for x in transaction.edus):
|
||||||
yield self.received_edu(origin, edu.edu_type, edu.content)
|
await self.received_edu(origin, edu.edu_type, edu.content)
|
||||||
|
|
||||||
response = {"pdus": pdu_results}
|
response = {"pdus": pdu_results}
|
||||||
|
|
||||||
logger.debug("Returning: %s", str(response))
|
logger.debug("Returning: %s", str(response))
|
||||||
|
|
||||||
yield self.transaction_actions.set_response(origin, transaction, 200, response)
|
await self.transaction_actions.set_response(origin, transaction, 200, response)
|
||||||
return 200, response
|
return 200, response
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def received_edu(self, origin, edu_type, content):
|
||||||
def received_edu(self, origin, edu_type, content):
|
|
||||||
received_edus_counter.inc()
|
received_edus_counter.inc()
|
||||||
yield self.registry.on_edu(edu_type, origin, content)
|
await self.registry.on_edu(edu_type, origin, content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_context_state_request(self, origin, room_id, event_id):
|
||||||
@log_function
|
|
||||||
def on_context_state_request(self, origin, room_id, event_id):
|
|
||||||
if not event_id:
|
if not event_id:
|
||||||
raise NotImplementedError("Specify an event")
|
raise NotImplementedError("Specify an event")
|
||||||
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
|
@ -289,8 +279,8 @@ class FederationServer(FederationBase):
|
||||||
# in the cache so we could return it without waiting for the linearizer
|
# in the cache so we could return it without waiting for the linearizer
|
||||||
# - but that's non-trivial to get right, and anyway somewhat defeats
|
# - but that's non-trivial to get right, and anyway somewhat defeats
|
||||||
# the point of the linearizer.
|
# the point of the linearizer.
|
||||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
with (await self._server_linearizer.queue((origin, room_id))):
|
||||||
resp = yield self._state_resp_cache.wrap(
|
resp = await self._state_resp_cache.wrap(
|
||||||
(room_id, event_id),
|
(room_id, event_id),
|
||||||
self._on_context_state_request_compute,
|
self._on_context_state_request_compute,
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -299,65 +289,58 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
return 200, resp
|
return 200, resp
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_state_ids_request(self, origin, room_id, event_id):
|
||||||
def on_state_ids_request(self, origin, room_id, event_id):
|
|
||||||
if not event_id:
|
if not event_id:
|
||||||
raise NotImplementedError("Specify an event")
|
raise NotImplementedError("Specify an event")
|
||||||
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
in_room = await self.auth.check_host_in_room(room_id, origin)
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
|
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
||||||
auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
|
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
|
||||||
|
|
||||||
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _on_context_state_request_compute(self, room_id, event_id):
|
||||||
def _on_context_state_request_compute(self, room_id, event_id):
|
pdus = await self.handler.get_state_for_pdu(room_id, event_id)
|
||||||
pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
|
auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
|
||||||
auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||||
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_pdu_request(self, origin, event_id):
|
||||||
@log_function
|
pdu = await self.handler.get_persisted_pdu(origin, event_id)
|
||||||
def on_pdu_request(self, origin, event_id):
|
|
||||||
pdu = yield self.handler.get_persisted_pdu(origin, event_id)
|
|
||||||
|
|
||||||
if pdu:
|
if pdu:
|
||||||
return 200, self._transaction_from_pdus([pdu]).get_dict()
|
return 200, self._transaction_from_pdus([pdu]).get_dict()
|
||||||
else:
|
else:
|
||||||
return 404, ""
|
return 404, ""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_query_request(self, query_type, args):
|
||||||
def on_query_request(self, query_type, args):
|
|
||||||
received_queries_counter.labels(query_type).inc()
|
received_queries_counter.labels(query_type).inc()
|
||||||
resp = yield self.registry.on_query(query_type, args)
|
resp = await self.registry.on_query(query_type, args)
|
||||||
return 200, resp
|
return 200, resp
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_make_join_request(self, origin, room_id, user_id, supported_versions):
|
||||||
def on_make_join_request(self, origin, room_id, user_id, supported_versions):
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
if room_version not in supported_versions:
|
if room_version not in supported_versions:
|
||||||
logger.warn("Room version %s not in %s", room_version, supported_versions)
|
logger.warn("Room version %s not in %s", room_version, supported_versions)
|
||||||
raise IncompatibleRoomVersionError(room_version=room_version)
|
raise IncompatibleRoomVersionError(room_version=room_version)
|
||||||
|
|
||||||
pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
|
pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
|
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_invite_request(self, origin, content, room_version):
|
||||||
def on_invite_request(self, origin, content, room_version):
|
|
||||||
if room_version not in KNOWN_ROOM_VERSIONS:
|
if room_version not in KNOWN_ROOM_VERSIONS:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
|
@ -369,28 +352,27 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
pdu = event_from_pdu_json(content, format_ver)
|
pdu = event_from_pdu_json(content, format_ver)
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, pdu.room_id)
|
await self.check_server_matches_acl(origin_host, pdu.room_id)
|
||||||
pdu = yield self._check_sigs_and_hash(room_version, pdu)
|
pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||||
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
|
ret_pdu = await self.handler.on_invite_request(origin, pdu)
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
return {"event": ret_pdu.get_pdu_json(time_now)}
|
return {"event": ret_pdu.get_pdu_json(time_now)}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_send_join_request(self, origin, content, room_id):
|
||||||
def on_send_join_request(self, origin, content, room_id):
|
|
||||||
logger.debug("on_send_join_request: content: %s", content)
|
logger.debug("on_send_join_request: content: %s", content)
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
format_ver = room_version_to_event_format(room_version)
|
format_ver = room_version_to_event_format(room_version)
|
||||||
pdu = event_from_pdu_json(content, format_ver)
|
pdu = event_from_pdu_json(content, format_ver)
|
||||||
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, pdu.room_id)
|
await self.check_server_matches_acl(origin_host, pdu.room_id)
|
||||||
|
|
||||||
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
|
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
|
||||||
|
|
||||||
pdu = yield self._check_sigs_and_hash(room_version, pdu)
|
pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||||
|
|
||||||
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
|
res_pdus = await self.handler.on_send_join_request(origin, pdu)
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
return (
|
return (
|
||||||
200,
|
200,
|
||||||
|
@ -402,48 +384,44 @@ class FederationServer(FederationBase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_make_leave_request(self, origin, room_id, user_id):
|
||||||
def on_make_leave_request(self, origin, room_id, user_id):
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id)
|
pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
|
return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_send_leave_request(self, origin, content, room_id):
|
||||||
def on_send_leave_request(self, origin, content, room_id):
|
|
||||||
logger.debug("on_send_leave_request: content: %s", content)
|
logger.debug("on_send_leave_request: content: %s", content)
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
format_ver = room_version_to_event_format(room_version)
|
format_ver = room_version_to_event_format(room_version)
|
||||||
pdu = event_from_pdu_json(content, format_ver)
|
pdu = event_from_pdu_json(content, format_ver)
|
||||||
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, pdu.room_id)
|
await self.check_server_matches_acl(origin_host, pdu.room_id)
|
||||||
|
|
||||||
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
|
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
|
||||||
|
|
||||||
pdu = yield self._check_sigs_and_hash(room_version, pdu)
|
pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||||
|
|
||||||
yield self.handler.on_send_leave_request(origin, pdu)
|
await self.handler.on_send_leave_request(origin, pdu)
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_event_auth(self, origin, room_id, event_id):
|
||||||
def on_event_auth(self, origin, room_id, event_id):
|
with (await self._server_linearizer.queue((origin, room_id))):
|
||||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
auth_pdus = yield self.handler.on_event_auth(event_id)
|
auth_pdus = await self.handler.on_event_auth(event_id)
|
||||||
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
|
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
|
||||||
return 200, res
|
return 200, res
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_query_auth_request(self, origin, content, room_id, event_id):
|
||||||
def on_query_auth_request(self, origin, content, room_id, event_id):
|
|
||||||
"""
|
"""
|
||||||
Content is a dict with keys::
|
Content is a dict with keys::
|
||||||
auth_chain (list): A list of events that give the auth chain.
|
auth_chain (list): A list of events that give the auth chain.
|
||||||
|
@ -462,22 +440,22 @@ class FederationServer(FederationBase):
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Results in `dict` with the same format as `content`
|
Deferred: Results in `dict` with the same format as `content`
|
||||||
"""
|
"""
|
||||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
with (await self._server_linearizer.queue((origin, room_id))):
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
format_ver = room_version_to_event_format(room_version)
|
format_ver = room_version_to_event_format(room_version)
|
||||||
|
|
||||||
auth_chain = [
|
auth_chain = [
|
||||||
event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
|
event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
|
||||||
]
|
]
|
||||||
|
|
||||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
signed_auth = await self._check_sigs_and_hash_and_fetch(
|
||||||
origin, auth_chain, outlier=True, room_version=room_version
|
origin, auth_chain, outlier=True, room_version=room_version
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = yield self.handler.on_query_auth(
|
ret = await self.handler.on_query_auth(
|
||||||
origin,
|
origin,
|
||||||
event_id,
|
event_id,
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -503,16 +481,14 @@ class FederationServer(FederationBase):
|
||||||
return self.on_query_request("user_devices", user_id)
|
return self.on_query_request("user_devices", user_id)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
async def on_claim_client_keys(self, origin, content):
|
||||||
@log_function
|
|
||||||
def on_claim_client_keys(self, origin, content):
|
|
||||||
query = []
|
query = []
|
||||||
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||||
for device_id, algorithm in device_keys.items():
|
for device_id, algorithm in device_keys.items():
|
||||||
query.append((user_id, device_id, algorithm))
|
query.append((user_id, device_id, algorithm))
|
||||||
|
|
||||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||||
results = yield self.store.claim_e2e_one_time_keys(query)
|
results = await self.store.claim_e2e_one_time_keys(query)
|
||||||
|
|
||||||
json_result = {}
|
json_result = {}
|
||||||
for user_id, device_keys in results.items():
|
for user_id, device_keys in results.items():
|
||||||
|
@ -536,14 +512,12 @@ class FederationServer(FederationBase):
|
||||||
|
|
||||||
return {"one_time_keys": json_result}
|
return {"one_time_keys": json_result}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_get_missing_events(
|
||||||
@log_function
|
|
||||||
def on_get_missing_events(
|
|
||||||
self, origin, room_id, earliest_events, latest_events, limit
|
self, origin, room_id, earliest_events, latest_events, limit
|
||||||
):
|
):
|
||||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
with (await self._server_linearizer.queue((origin, room_id))):
|
||||||
origin_host, _ = parse_server_name(origin)
|
origin_host, _ = parse_server_name(origin)
|
||||||
yield self.check_server_matches_acl(origin_host, room_id)
|
await self.check_server_matches_acl(origin_host, room_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
|
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
|
||||||
|
@ -553,7 +527,7 @@ class FederationServer(FederationBase):
|
||||||
limit,
|
limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
missing_events = yield self.handler.on_get_missing_events(
|
missing_events = await self.handler.on_get_missing_events(
|
||||||
origin, room_id, earliest_events, latest_events, limit
|
origin, room_id, earliest_events, latest_events, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -586,8 +560,7 @@ class FederationServer(FederationBase):
|
||||||
destination=None,
|
destination=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _handle_received_pdu(self, origin, pdu):
|
||||||
def _handle_received_pdu(self, origin, pdu):
|
|
||||||
""" Process a PDU received in a federation /send/ transaction.
|
""" Process a PDU received in a federation /send/ transaction.
|
||||||
|
|
||||||
If the event is invalid, then this method throws a FederationError.
|
If the event is invalid, then this method throws a FederationError.
|
||||||
|
@ -640,37 +613,34 @@ class FederationServer(FederationBase):
|
||||||
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
|
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
|
||||||
|
|
||||||
# We've already checked that we know the room version by this point
|
# We've already checked that we know the room version by this point
|
||||||
room_version = yield self.store.get_room_version(pdu.room_id)
|
room_version = await self.store.get_room_version(pdu.room_id)
|
||||||
|
|
||||||
# Check signature.
|
# Check signature.
|
||||||
try:
|
try:
|
||||||
pdu = yield self._check_sigs_and_hash(room_version, pdu)
|
pdu = await self._check_sigs_and_hash(room_version, pdu)
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
|
raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
|
||||||
|
|
||||||
yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
|
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "<ReplicationLayer(%s)>" % self.server_name
|
return "<ReplicationLayer(%s)>" % self.server_name
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def exchange_third_party_invite(
|
||||||
def exchange_third_party_invite(
|
|
||||||
self, sender_user_id, target_user_id, room_id, signed
|
self, sender_user_id, target_user_id, room_id, signed
|
||||||
):
|
):
|
||||||
ret = yield self.handler.exchange_third_party_invite(
|
ret = await self.handler.exchange_third_party_invite(
|
||||||
sender_user_id, target_user_id, room_id, signed
|
sender_user_id, target_user_id, room_id, signed
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_exchange_third_party_invite_request(self, room_id, event_dict):
|
||||||
def on_exchange_third_party_invite_request(self, room_id, event_dict):
|
ret = await self.handler.on_exchange_third_party_invite_request(
|
||||||
ret = yield self.handler.on_exchange_third_party_invite_request(
|
|
||||||
room_id, event_dict
|
room_id, event_dict
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_server_matches_acl(self, server_name, room_id):
|
||||||
def check_server_matches_acl(self, server_name, room_id):
|
|
||||||
"""Check if the given server is allowed by the server ACLs in the room
|
"""Check if the given server is allowed by the server ACLs in the room
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -680,13 +650,13 @@ class FederationServer(FederationBase):
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if the server does not match the ACL
|
AuthError if the server does not match the ACL
|
||||||
"""
|
"""
|
||||||
state_ids = yield self.store.get_current_state_ids(room_id)
|
state_ids = await self.store.get_current_state_ids(room_id)
|
||||||
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
|
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
|
||||||
|
|
||||||
if not acl_event_id:
|
if not acl_event_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
acl_event = yield self.store.get_event(acl_event_id)
|
acl_event = await self.store.get_event(acl_event_id)
|
||||||
if server_matches_acl_event(server_name, acl_event):
|
if server_matches_acl_event(server_name, acl_event):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -799,15 +769,14 @@ class FederationHandlerRegistry(object):
|
||||||
|
|
||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_edu(self, edu_type, origin, content):
|
||||||
def on_edu(self, edu_type, origin, content):
|
|
||||||
handler = self.edu_handlers.get(edu_type)
|
handler = self.edu_handlers.get(edu_type)
|
||||||
if not handler:
|
if not handler:
|
||||||
logger.warn("No handler registered for EDU type %s", edu_type)
|
logger.warn("No handler registered for EDU type %s", edu_type)
|
||||||
|
|
||||||
with start_active_span_from_edu(content, "handle_edu"):
|
with start_active_span_from_edu(content, "handle_edu"):
|
||||||
try:
|
try:
|
||||||
yield handler(origin, content)
|
await handler(origin, content)
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
logger.info("Failed to handle edu %r: %r", edu_type, e)
|
logger.info("Failed to handle edu %r: %r", edu_type, e)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -840,7 +809,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
|
||||||
|
|
||||||
super(ReplicationFederationHandlerRegistry, self).__init__()
|
super(ReplicationFederationHandlerRegistry, self).__init__()
|
||||||
|
|
||||||
def on_edu(self, edu_type, origin, content):
|
async def on_edu(self, edu_type, origin, content):
|
||||||
"""Overrides FederationHandlerRegistry
|
"""Overrides FederationHandlerRegistry
|
||||||
"""
|
"""
|
||||||
if not self.config.use_presence and edu_type == "m.presence":
|
if not self.config.use_presence and edu_type == "m.presence":
|
||||||
|
@ -848,17 +817,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
|
||||||
|
|
||||||
handler = self.edu_handlers.get(edu_type)
|
handler = self.edu_handlers.get(edu_type)
|
||||||
if handler:
|
if handler:
|
||||||
return super(ReplicationFederationHandlerRegistry, self).on_edu(
|
return await super(ReplicationFederationHandlerRegistry, self).on_edu(
|
||||||
edu_type, origin, content
|
edu_type, origin, content
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._send_edu(edu_type=edu_type, origin=origin, content=content)
|
return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
|
||||||
|
|
||||||
def on_query(self, query_type, args):
|
async def on_query(self, query_type, args):
|
||||||
"""Overrides FederationHandlerRegistry
|
"""Overrides FederationHandlerRegistry
|
||||||
"""
|
"""
|
||||||
handler = self.query_handlers.get(query_type)
|
handler = self.query_handlers.get(query_type)
|
||||||
if handler:
|
if handler:
|
||||||
return handler(args)
|
return await handler(args)
|
||||||
|
|
||||||
return self._get_query_client(query_type=query_type, args=args)
|
return await self._get_query_client(query_type=query_type, args=args)
|
||||||
|
|
|
@ -138,7 +138,7 @@ def concurrently_execute(func, args, limit):
|
||||||
the number of concurrent executions.
|
the number of concurrent executions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (func): Function to execute, should return a deferred.
|
func (func): Function to execute, should return a deferred or coroutine.
|
||||||
args (list): List of arguments to pass to func, each invocation of func
|
args (list): List of arguments to pass to func, each invocation of func
|
||||||
gets a signle argument.
|
gets a signle argument.
|
||||||
limit (int): Maximum number of conccurent executions.
|
limit (int): Maximum number of conccurent executions.
|
||||||
|
@ -148,11 +148,10 @@ def concurrently_execute(func, args, limit):
|
||||||
"""
|
"""
|
||||||
it = iter(args)
|
it = iter(args)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _concurrently_execute_inner():
|
||||||
def _concurrently_execute_inner():
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
yield func(next(it))
|
await maybe_awaitable(func(next(it)))
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.datastore.get_to_device_stream_token = lambda: 0
|
self.datastore.get_to_device_stream_token = lambda: 0
|
||||||
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
|
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
|
||||||
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
||||||
|
self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
def test_started_typing_local(self):
|
def test_started_typing_local(self):
|
||||||
self.room_members = [U_APPLE, U_BANANA]
|
self.room_members = [U_APPLE, U_BANANA]
|
||||||
|
|
Loading…
Reference in New Issue