Add new API appservice specific public room list
This commit is contained in:
parent
194b6259c5
commit
f32fb65552
|
@ -89,6 +89,9 @@ class ApplicationService(object):
|
||||||
self.namespaces = self._check_namespaces(namespaces)
|
self.namespaces = self._check_namespaces(namespaces)
|
||||||
self.id = id
|
self.id = id
|
||||||
|
|
||||||
|
if "|" in self.id:
|
||||||
|
raise Exception("application service ID cannot contain '|' character")
|
||||||
|
|
||||||
# .protocols is a publicly visible field
|
# .protocols is a publicly visible field
|
||||||
if protocols:
|
if protocols:
|
||||||
self.protocols = set(protocols)
|
self.protocols = set(protocols)
|
||||||
|
|
|
@ -19,6 +19,7 @@ from synapse.api.errors import CodeMessageException
|
||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
from synapse.types import ThirdPartyInstanceID
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
|
@ -177,6 +178,14 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
" valid result", uri)
|
" valid result", uri)
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
for instance in info.get("instances", []):
|
||||||
|
instance["appservice_id"] = service.id
|
||||||
|
network_id = instance.get("network_id", None)
|
||||||
|
if network_id is not None:
|
||||||
|
instance["network_id"] = ThirdPartyInstanceID(
|
||||||
|
service.id, network_id,
|
||||||
|
).to_string()
|
||||||
|
|
||||||
defer.returnValue(info)
|
defer.returnValue(info)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning("query_3pe_protocol to %s threw exception %s",
|
logger.warning("query_3pe_protocol to %s threw exception %s",
|
||||||
|
|
|
@ -655,12 +655,15 @@ class FederationClient(FederationBase):
|
||||||
raise RuntimeError("Failed to send to any server.")
|
raise RuntimeError("Failed to send to any server.")
|
||||||
|
|
||||||
def get_public_rooms(self, destination, limit=None, since_token=None,
|
def get_public_rooms(self, destination, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None, include_all_networks=False,
|
||||||
|
third_party_instance_id=None):
|
||||||
if destination == self.server_name:
|
if destination == self.server_name:
|
||||||
return
|
return
|
||||||
|
|
||||||
return self.transport_layer.get_public_rooms(
|
return self.transport_layer.get_public_rooms(
|
||||||
destination, limit, since_token, search_filter
|
destination, limit, since_token, search_filter,
|
||||||
|
include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -249,10 +249,15 @@ class TransportLayerClient(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_public_rooms(self, remote_server, limit, since_token,
|
def get_public_rooms(self, remote_server, limit, since_token,
|
||||||
search_filter=None):
|
search_filter=None, include_all_networks=False,
|
||||||
|
third_party_instance_id=None):
|
||||||
path = PREFIX + "/publicRooms"
|
path = PREFIX + "/publicRooms"
|
||||||
|
|
||||||
args = {}
|
args = {
|
||||||
|
"include_all_networks": "true" if include_all_networks else "false",
|
||||||
|
}
|
||||||
|
if third_party_instance_id:
|
||||||
|
args["third_party_instance_id"] = third_party_instance_id,
|
||||||
if limit:
|
if limit:
|
||||||
args["limit"] = [str(limit)]
|
args["limit"] = [str(limit)]
|
||||||
if since_token:
|
if since_token:
|
||||||
|
|
|
@ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
||||||
|
parse_boolean_from_args,
|
||||||
)
|
)
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
from synapse.util.versionstring import get_version_string
|
from synapse.util.versionstring import get_version_string
|
||||||
|
from synapse.types import ThirdPartyInstanceID
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
@ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet):
|
||||||
def on_GET(self, origin, content, query):
|
def on_GET(self, origin, content, query):
|
||||||
limit = parse_integer_from_args(query, "limit", 0)
|
limit = parse_integer_from_args(query, "limit", 0)
|
||||||
since_token = parse_string_from_args(query, "since", None)
|
since_token = parse_string_from_args(query, "since", None)
|
||||||
|
include_all_networks = parse_boolean_from_args(
|
||||||
|
query, "include_all_networks", False
|
||||||
|
)
|
||||||
|
third_party_instance_id = parse_string_from_args(
|
||||||
|
query, "third_party_instance_id", None
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_all_networks:
|
||||||
|
network_tuple = None
|
||||||
|
elif third_party_instance_id:
|
||||||
|
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
|
||||||
|
else:
|
||||||
|
network_tuple = ThirdPartyInstanceID(None, None)
|
||||||
|
|
||||||
data = yield self.room_list_handler.get_local_public_room_list(
|
data = yield self.room_list_handler.get_local_public_room_list(
|
||||||
limit, since_token
|
limit, since_token,
|
||||||
|
network_tuple=network_tuple
|
||||||
)
|
)
|
||||||
defer.returnValue((200, data))
|
defer.returnValue((200, data))
|
||||||
|
|
||||||
|
|
|
@ -339,3 +339,15 @@ class DirectoryHandler(BaseHandler):
|
||||||
yield self.auth.check_can_change_room_list(room_id, requester.user)
|
yield self.auth.check_can_change_room_list(room_id, requester.user)
|
||||||
|
|
||||||
yield self.store.set_room_is_public(room_id, visibility == "public")
|
yield self.store.set_room_is_public(room_id, visibility == "public")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def edit_published_appservice_room_list(self, appservice_id, network_id,
|
||||||
|
room_id, visibility):
|
||||||
|
"""Edit the appservice/network specific public room list.
|
||||||
|
"""
|
||||||
|
if visibility not in ["public", "private"]:
|
||||||
|
raise SynapseError(400, "Invalid visibility setting")
|
||||||
|
|
||||||
|
yield self.store.set_room_is_public_appservice(
|
||||||
|
room_id, appservice_id, network_id, visibility == "public"
|
||||||
|
)
|
||||||
|
|
|
@ -22,6 +22,7 @@ from synapse.api.constants import (
|
||||||
)
|
)
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async import concurrently_execute
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
from synapse.types import ThirdPartyInstanceID
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from unpaddedbase64 import encode_base64, decode_base64
|
from unpaddedbase64 import encode_base64, decode_base64
|
||||||
|
@ -34,6 +35,10 @@ logger = logging.getLogger(__name__)
|
||||||
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
|
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
|
# This is used to indicate we should only return rooms published to the main list.
|
||||||
|
EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
|
||||||
|
|
||||||
|
|
||||||
class RoomListHandler(BaseHandler):
|
class RoomListHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RoomListHandler, self).__init__(hs)
|
super(RoomListHandler, self).__init__(hs)
|
||||||
|
@ -41,10 +46,27 @@ class RoomListHandler(BaseHandler):
|
||||||
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
|
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
|
||||||
|
|
||||||
def get_local_public_room_list(self, limit=None, since_token=None,
|
def get_local_public_room_list(self, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None,
|
||||||
if search_filter:
|
network_tuple=EMTPY_THIRD_PARTY_ID,):
|
||||||
|
"""Generate a local public room list.
|
||||||
|
|
||||||
|
There are multiple different lists: the main one plus one per third
|
||||||
|
party network. A client can ask for a specific list or to return all.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit (int)
|
||||||
|
since_token (str)
|
||||||
|
search_filter (dict)
|
||||||
|
network_tuple (ThirdPartyInstanceID): Which public list to use.
|
||||||
|
This can be (None, None) to indicate the main list, or a particular
|
||||||
|
appservice and network id to use an appservice specific one.
|
||||||
|
Setting to None returns all public rooms across all lists.
|
||||||
|
"""
|
||||||
|
if search_filter or network_tuple is not (None, None):
|
||||||
# We explicitly don't bother caching searches.
|
# We explicitly don't bother caching searches.
|
||||||
return self._get_public_room_list(limit, since_token, search_filter)
|
return self._get_public_room_list(
|
||||||
|
limit, since_token, search_filter, network_tuple=network_tuple,
|
||||||
|
)
|
||||||
|
|
||||||
result = self.response_cache.get((limit, since_token))
|
result = self.response_cache.get((limit, since_token))
|
||||||
if not result:
|
if not result:
|
||||||
|
@ -56,7 +78,8 @@ class RoomListHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_public_room_list(self, limit=None, since_token=None,
|
def _get_public_room_list(self, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None,
|
||||||
|
network_tuple=EMTPY_THIRD_PARTY_ID,):
|
||||||
if since_token and since_token != "END":
|
if since_token and since_token != "END":
|
||||||
since_token = RoomListNextBatch.from_token(since_token)
|
since_token = RoomListNextBatch.from_token(since_token)
|
||||||
else:
|
else:
|
||||||
|
@ -73,14 +96,15 @@ class RoomListHandler(BaseHandler):
|
||||||
current_public_id = yield self.store.get_current_public_room_stream_id()
|
current_public_id = yield self.store.get_current_public_room_stream_id()
|
||||||
public_room_stream_id = since_token.public_room_stream_id
|
public_room_stream_id = since_token.public_room_stream_id
|
||||||
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
|
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
|
||||||
public_room_stream_id, current_public_id
|
public_room_stream_id, current_public_id,
|
||||||
|
network_tuple=network_tuple,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
stream_token = yield self.store.get_room_max_stream_ordering()
|
stream_token = yield self.store.get_room_max_stream_ordering()
|
||||||
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
|
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
|
||||||
|
|
||||||
room_ids = yield self.store.get_public_room_ids_at_stream_id(
|
room_ids = yield self.store.get_public_room_ids_at_stream_id(
|
||||||
public_room_stream_id
|
public_room_stream_id, network_tuple=network_tuple,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We want to return rooms in a particular order: the number of joined
|
# We want to return rooms in a particular order: the number of joined
|
||||||
|
@ -311,7 +335,8 @@ class RoomListHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
|
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None, include_all_networks=False,
|
||||||
|
third_party_instance_id=None,):
|
||||||
if search_filter:
|
if search_filter:
|
||||||
# We currently don't support searching across federation, so we have
|
# We currently don't support searching across federation, so we have
|
||||||
# to do it manually without pagination
|
# to do it manually without pagination
|
||||||
|
@ -320,6 +345,8 @@ class RoomListHandler(BaseHandler):
|
||||||
|
|
||||||
res = yield self._get_remote_list_cached(
|
res = yield self._get_remote_list_cached(
|
||||||
server_name, limit=limit, since_token=since_token,
|
server_name, limit=limit, since_token=since_token,
|
||||||
|
include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if search_filter:
|
if search_filter:
|
||||||
|
@ -332,22 +359,30 @@ class RoomListHandler(BaseHandler):
|
||||||
defer.returnValue(res)
|
defer.returnValue(res)
|
||||||
|
|
||||||
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
|
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
|
||||||
search_filter=None):
|
search_filter=None, include_all_networks=False,
|
||||||
|
third_party_instance_id=None,):
|
||||||
repl_layer = self.hs.get_replication_layer()
|
repl_layer = self.hs.get_replication_layer()
|
||||||
if search_filter:
|
if search_filter:
|
||||||
# We can't cache when asking for search
|
# We can't cache when asking for search
|
||||||
return repl_layer.get_public_rooms(
|
return repl_layer.get_public_rooms(
|
||||||
server_name, limit=limit, since_token=since_token,
|
server_name, limit=limit, since_token=since_token,
|
||||||
search_filter=search_filter,
|
search_filter=search_filter, include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = self.remote_response_cache.get((server_name, limit, since_token))
|
key = (
|
||||||
|
server_name, limit, since_token, include_all_networks,
|
||||||
|
third_party_instance_id,
|
||||||
|
)
|
||||||
|
result = self.remote_response_cache.get(key)
|
||||||
if not result:
|
if not result:
|
||||||
result = self.remote_response_cache.set(
|
result = self.remote_response_cache.set(
|
||||||
(server_name, limit, since_token),
|
key,
|
||||||
repl_layer.get_public_rooms(
|
repl_layer.get_public_rooms(
|
||||||
server_name, limit=limit, since_token=since_token,
|
server_name, limit=limit, since_token=since_token,
|
||||||
search_filter=search_filter,
|
search_filter=search_filter,
|
||||||
|
include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -78,12 +78,16 @@ def parse_boolean(request, name, default=None, required=False):
|
||||||
parameter is present and not one of "true" or "false".
|
parameter is present and not one of "true" or "false".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if name in request.args:
|
return parse_boolean_from_args(request.args, name, default, required)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_boolean_from_args(args, name, default=None, required=False):
|
||||||
|
if name in args:
|
||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
"true": True,
|
"true": True,
|
||||||
"false": False,
|
"false": False,
|
||||||
}[request.args[name][0]]
|
}[args[name][0]]
|
||||||
except:
|
except:
|
||||||
message = (
|
message = (
|
||||||
"Boolean query parameter %r must be one of"
|
"Boolean query parameter %r must be one of"
|
||||||
|
|
|
@ -475,7 +475,7 @@ class ReplicationResource(Resource):
|
||||||
)
|
)
|
||||||
upto_token = _position_from_rows(public_rooms_rows, current_position)
|
upto_token = _position_from_rows(public_rooms_rows, current_position)
|
||||||
writer.write_header_and_rows("public_rooms", public_rooms_rows, (
|
writer.write_header_and_rows("public_rooms", public_rooms_rows, (
|
||||||
"position", "room_id", "visibility"
|
"position", "room_id", "visibility", "appservice_id", "network_id",
|
||||||
), position=upto_token)
|
), position=upto_token)
|
||||||
|
|
||||||
def federation(self, writer, current_token, limit, request_streams, federation_ack):
|
def federation(self, writer, current_token, limit, request_streams, federation_ack):
|
||||||
|
|
|
@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
ClientDirectoryServer(hs).register(http_server)
|
ClientDirectoryServer(hs).register(http_server)
|
||||||
ClientDirectoryListServer(hs).register(http_server)
|
ClientDirectoryListServer(hs).register(http_server)
|
||||||
|
ClientAppserviceDirectoryListServer(hs).register(http_server)
|
||||||
|
|
||||||
|
|
||||||
class ClientDirectoryServer(ClientV1RestServlet):
|
class ClientDirectoryServer(ClientV1RestServlet):
|
||||||
|
@ -184,3 +185,36 @@ class ClientDirectoryListServer(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
|
||||||
|
PATTERNS = client_path_patterns(
|
||||||
|
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ClientAppserviceDirectoryListServer, self).__init__(hs)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.handlers = hs.get_handlers()
|
||||||
|
|
||||||
|
def on_PUT(self, request, network_id, room_id):
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
visibility = content.get("visibility", "public")
|
||||||
|
return self._edit(request, network_id, room_id, visibility)
|
||||||
|
|
||||||
|
def on_DELETE(self, request, network_id, room_id):
|
||||||
|
return self._edit(request, network_id, room_id, "private")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _edit(self, request, network_id, room_id, visibility):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
if not requester.app_service:
|
||||||
|
raise AuthError(
|
||||||
|
403, "Only appservices can edit the appservice published room list"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.handlers.directory_handler.edit_published_appservice_room_list(
|
||||||
|
requester.app_service.id, network_id, room_id, visibility,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
|
@ -21,7 +21,7 @@ from synapse.api.errors import SynapseError, Codes, AuthError
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.types import UserID, RoomID, RoomAlias
|
from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID
|
||||||
from synapse.events.utils import serialize_event, format_event_for_client_v2
|
from synapse.events.utils import serialize_event, format_event_for_client_v2
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
parse_json_object_from_request, parse_string, parse_integer
|
parse_json_object_from_request, parse_string, parse_integer
|
||||||
|
@ -321,6 +321,20 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
||||||
since_token = content.get("since", None)
|
since_token = content.get("since", None)
|
||||||
search_filter = content.get("filter", None)
|
search_filter = content.get("filter", None)
|
||||||
|
|
||||||
|
include_all_networks = content.get("include_all_networks", False)
|
||||||
|
third_party_instance_id = content.get("third_party_instance_id", None)
|
||||||
|
|
||||||
|
if include_all_networks:
|
||||||
|
network_tuple = None
|
||||||
|
if third_party_instance_id is not None:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Can't use include_all_networks with an explicit network"
|
||||||
|
)
|
||||||
|
elif third_party_instance_id is None:
|
||||||
|
network_tuple = ThirdPartyInstanceID(None, None)
|
||||||
|
else:
|
||||||
|
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
|
||||||
|
|
||||||
handler = self.hs.get_room_list_handler()
|
handler = self.hs.get_room_list_handler()
|
||||||
if server:
|
if server:
|
||||||
data = yield handler.get_remote_public_room_list(
|
data = yield handler.get_remote_public_room_list(
|
||||||
|
@ -328,12 +342,15 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
||||||
limit=limit,
|
limit=limit,
|
||||||
since_token=since_token,
|
since_token=since_token,
|
||||||
search_filter=search_filter,
|
search_filter=search_filter,
|
||||||
|
include_all_networks=include_all_networks,
|
||||||
|
third_party_instance_id=third_party_instance_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data = yield handler.get_local_public_room_list(
|
data = yield handler.get_local_public_room_list(
|
||||||
limit=limit,
|
limit=limit,
|
||||||
since_token=since_token,
|
since_token=since_token,
|
||||||
search_filter=search_filter,
|
search_filter=search_filter,
|
||||||
|
network_tuple=network_tuple,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, data))
|
defer.returnValue((200, data))
|
||||||
|
|
|
@ -106,7 +106,11 @@ class RoomStore(SQLBaseStore):
|
||||||
entries = self._simple_select_list_txn(
|
entries = self._simple_select_list_txn(
|
||||||
txn,
|
txn,
|
||||||
table="public_room_list_stream",
|
table="public_room_list_stream",
|
||||||
keyvalues={"room_id": room_id},
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"appservice_id": None,
|
||||||
|
"network_id": None,
|
||||||
|
},
|
||||||
retcols=("stream_id", "visibility"),
|
retcols=("stream_id", "visibility"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -124,6 +128,8 @@ class RoomStore(SQLBaseStore):
|
||||||
"stream_id": next_id,
|
"stream_id": next_id,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"visibility": is_public,
|
"visibility": is_public,
|
||||||
|
"appservice_id": None,
|
||||||
|
"network_id": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -133,6 +139,73 @@ class RoomStore(SQLBaseStore):
|
||||||
set_room_is_public_txn, next_id,
|
set_room_is_public_txn, next_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
|
||||||
|
is_public):
|
||||||
|
"""Edit the appservice/network specific public room list.
|
||||||
|
"""
|
||||||
|
def set_room_is_public_appservice_txn(txn, next_id):
|
||||||
|
if is_public:
|
||||||
|
try:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="appservice_room_list",
|
||||||
|
values={
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"network_id": "network_id",
|
||||||
|
"room_id": room_id
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except self.database_engine.module.IntegrityError:
|
||||||
|
# We've already inserted, nothing to do.
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="appservice_room_list",
|
||||||
|
keyvalues={
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"network_id": network_id,
|
||||||
|
"room_id": room_id
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
entries = self._simple_select_list_txn(
|
||||||
|
txn,
|
||||||
|
table="public_room_list_stream",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"network_id": network_id,
|
||||||
|
},
|
||||||
|
retcols=("stream_id", "visibility"),
|
||||||
|
)
|
||||||
|
|
||||||
|
entries.sort(key=lambda r: r["stream_id"])
|
||||||
|
|
||||||
|
add_to_stream = True
|
||||||
|
if entries:
|
||||||
|
add_to_stream = bool(entries[-1]["visibility"]) != is_public
|
||||||
|
|
||||||
|
if add_to_stream:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="public_room_list_stream",
|
||||||
|
values={
|
||||||
|
"stream_id": next_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"visibility": is_public,
|
||||||
|
"appservice_id": appservice_id,
|
||||||
|
"network_id": network_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._public_room_id_gen.get_next() as next_id:
|
||||||
|
yield self.runInteraction(
|
||||||
|
"set_room_is_public_appservice",
|
||||||
|
set_room_is_public_appservice_txn, next_id,
|
||||||
|
)
|
||||||
|
|
||||||
def get_public_room_ids(self):
|
def get_public_room_ids(self):
|
||||||
return self._simple_select_onecol(
|
return self._simple_select_onecol(
|
||||||
table="rooms",
|
table="rooms",
|
||||||
|
@ -259,38 +332,95 @@ class RoomStore(SQLBaseStore):
|
||||||
def get_current_public_room_stream_id(self):
|
def get_current_public_room_stream_id(self):
|
||||||
return self._public_room_id_gen.get_current_token()
|
return self._public_room_id_gen.get_current_token()
|
||||||
|
|
||||||
def get_public_room_ids_at_stream_id(self, stream_id):
|
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
|
||||||
|
"""Get pulbic rooms for a particular list, or across all lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id (int)
|
||||||
|
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
|
||||||
|
means the main list, None means all lsits.
|
||||||
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_public_room_ids_at_stream_id",
|
"get_public_room_ids_at_stream_id",
|
||||||
self.get_public_room_ids_at_stream_id_txn, stream_id
|
self.get_public_room_ids_at_stream_id_txn,
|
||||||
|
stream_id, network_tuple=network_tuple
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id):
|
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
|
||||||
|
network_tuple):
|
||||||
return {
|
return {
|
||||||
rm
|
rm
|
||||||
for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items()
|
for rm, vis in self.get_published_at_stream_id_txn(
|
||||||
|
txn, stream_id, network_tuple=network_tuple
|
||||||
|
).items()
|
||||||
if vis
|
if vis
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_published_at_stream_id_txn(self, txn, stream_id):
|
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
|
||||||
sql = ("""
|
if network_tuple:
|
||||||
SELECT room_id, visibility FROM public_room_list_stream
|
# We want to get from a particular list. No aggregation required.
|
||||||
INNER JOIN (
|
|
||||||
SELECT room_id, max(stream_id) AS stream_id
|
sql = ("""
|
||||||
|
SELECT room_id, visibility FROM public_room_list_stream
|
||||||
|
INNER JOIN (
|
||||||
|
SELECT room_id, max(stream_id) AS stream_id
|
||||||
|
FROM public_room_list_stream
|
||||||
|
WHERE stream_id <= ? %s
|
||||||
|
GROUP BY room_id
|
||||||
|
) grouped USING (room_id, stream_id)
|
||||||
|
""")
|
||||||
|
|
||||||
|
if network_tuple.appservice_id is not None:
|
||||||
|
txn.execute(
|
||||||
|
sql % ("AND appservice_id = ? AND network_id = ?",),
|
||||||
|
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
txn.execute(
|
||||||
|
sql % ("AND appservice_id IS NULL",),
|
||||||
|
(stream_id,)
|
||||||
|
)
|
||||||
|
return dict(txn.fetchall())
|
||||||
|
else:
|
||||||
|
# We want to get from all lists, so we need to aggregate the results
|
||||||
|
|
||||||
|
logger.info("Executing full list")
|
||||||
|
|
||||||
|
sql = ("""
|
||||||
|
SELECT room_id, visibility
|
||||||
FROM public_room_list_stream
|
FROM public_room_list_stream
|
||||||
WHERE stream_id <= ?
|
INNER JOIN (
|
||||||
GROUP BY room_id
|
SELECT
|
||||||
) grouped USING (room_id, stream_id)
|
room_id, max(stream_id) AS stream_id, appservice_id,
|
||||||
""")
|
network_id
|
||||||
|
FROM public_room_list_stream
|
||||||
|
WHERE stream_id <= ?
|
||||||
|
GROUP BY room_id, appservice_id, network_id
|
||||||
|
) grouped USING (room_id, stream_id)
|
||||||
|
""")
|
||||||
|
|
||||||
txn.execute(sql, (stream_id,))
|
txn.execute(
|
||||||
return dict(txn.fetchall())
|
sql,
|
||||||
|
(stream_id,)
|
||||||
|
)
|
||||||
|
|
||||||
def get_public_room_changes(self, prev_stream_id, new_stream_id):
|
results = {}
|
||||||
|
# A room is visible if its visible on any list.
|
||||||
|
for room_id, visibility in txn.fetchall():
|
||||||
|
results[room_id] = bool(visibility) or results.get(room_id, False)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_public_room_changes(self, prev_stream_id, new_stream_id,
|
||||||
|
network_tuple):
|
||||||
def get_public_room_changes_txn(txn):
|
def get_public_room_changes_txn(txn):
|
||||||
then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id)
|
then_rooms = self.get_public_room_ids_at_stream_id_txn(
|
||||||
|
txn, prev_stream_id, network_tuple
|
||||||
|
)
|
||||||
|
|
||||||
now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id)
|
now_rooms_dict = self.get_published_at_stream_id_txn(
|
||||||
|
txn, new_stream_id, network_tuple
|
||||||
|
)
|
||||||
|
|
||||||
now_rooms_visible = set(
|
now_rooms_visible = set(
|
||||||
rm for rm, vis in now_rooms_dict.items() if vis
|
rm for rm, vis in now_rooms_dict.items() if vis
|
||||||
|
@ -311,7 +441,8 @@ class RoomStore(SQLBaseStore):
|
||||||
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
||||||
def get_all_new_public_rooms(txn):
|
def get_all_new_public_rooms(txn):
|
||||||
sql = ("""
|
sql = ("""
|
||||||
SELECT stream_id, room_id, visibility FROM public_room_list_stream
|
SELECT stream_id, room_id, visibility, appservice_id, network_id
|
||||||
|
FROM public_room_list_stream
|
||||||
WHERE stream_id > ? AND stream_id <= ?
|
WHERE stream_id > ? AND stream_id <= ?
|
||||||
ORDER BY stream_id ASC
|
ORDER BY stream_id ASC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE appservice_room_list(
|
||||||
|
appservice_id TEXT NOT NULL,
|
||||||
|
network_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list(
|
||||||
|
appservice_id, network_id, room_id
|
||||||
|
);
|
||||||
|
|
||||||
|
ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT;
|
||||||
|
ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT;
|
|
@ -274,3 +274,37 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
||||||
return "t%d-%d" % (self.topological, self.stream)
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
else:
|
else:
|
||||||
return "s%d" % (self.stream,)
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
|
|
||||||
|
class ThirdPartyInstanceID(
|
||||||
|
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
|
||||||
|
):
|
||||||
|
# Deny iteration because it will bite you if you try to create a singleton
|
||||||
|
# set by:
|
||||||
|
# users = set(user)
|
||||||
|
def __iter__(self):
|
||||||
|
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
|
||||||
|
|
||||||
|
# Because this class is a namedtuple of strings, it is deeply immutable.
|
||||||
|
def __copy__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __deepcopy__(self, memo):
|
||||||
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, s):
|
||||||
|
bits = s.split("|", 2)
|
||||||
|
if len(bits) != 2:
|
||||||
|
raise SynapseError(400, "Invalid ID %r" % (s,))
|
||||||
|
|
||||||
|
return cls(appservice_id=bits[0], network_id=bits[1])
|
||||||
|
|
||||||
|
def to_string(self):
|
||||||
|
return "%s|%s" % (self.appservice_id, self.network_id,)
|
||||||
|
|
||||||
|
__str__ = to_string
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, appservice_id, network_id,):
|
||||||
|
return cls(appservice_id=appservice_id, network_id=network_id)
|
||||||
|
|
Loading…
Reference in New Issue