Add bulk group publicised lookup API
This commit is contained in:
parent
b880ff190a
commit
ef8e578677
|
@ -812,3 +812,18 @@ class TransportLayerClient(object):
|
|||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def bulk_get_publicised_groups(self, destination, user_ids):
|
||||
"""Get the groups a list of users are publicising
|
||||
"""
|
||||
|
||||
path = PREFIX + "/get_groups_publicised"
|
||||
|
||||
content = {"user_ids": user_ids}
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
|
|
@ -1050,6 +1050,22 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
|
|||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
|
||||
"""Get roles in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/get_groups_publicised$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query):
|
||||
resp = yield self.handler.bulk_get_publicised_groups(
|
||||
content["user_ids"], proxy=False,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
FEDERATION_SERVLET_CLASSES = (
|
||||
FederationSendServlet,
|
||||
FederationPullServlet,
|
||||
|
@ -1102,6 +1118,7 @@ GROUP_SERVER_SERVLET_CLASSES = (
|
|||
GROUP_LOCAL_SERVLET_CLASSES = (
|
||||
FederationGroupsLocalInviteServlet,
|
||||
FederationGroupsRemoveLocalUserServlet,
|
||||
FederationGroupsBulkPublicisedServlet,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -313,3 +313,45 @@ class GroupsLocalHandler(object):
|
|||
def get_joined_groups(self, user_id):
|
||||
group_ids = yield self.store.get_joined_groups(user_id)
|
||||
defer.returnValue({"groups": group_ids})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_publicised_groups_for_user(self, user_id):
|
||||
if self.hs.is_mine_id(user_id):
|
||||
result = yield self.store.get_publicised_groups_for_user(user_id)
|
||||
defer.returnValue({"groups": result})
|
||||
else:
|
||||
result = yield self.transport_client.get_publicised_groups_for_user(
|
||||
get_domain_from_id(user_id), user_id
|
||||
)
|
||||
# TODO: Verify attestations
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
||||
destinations = {}
|
||||
locals = []
|
||||
|
||||
for user_id in user_ids:
|
||||
if self.hs.is_mine_id(user_id):
|
||||
locals.append(user_id)
|
||||
else:
|
||||
destinations.setdefault(
|
||||
get_domain_from_id(user_id), []
|
||||
).append(user_id)
|
||||
|
||||
if not proxy and destinations:
|
||||
raise SynapseError(400, "Some user_ids are not local")
|
||||
|
||||
results = {}
|
||||
for destination, dest_user_ids in destinations.iteritems():
|
||||
r = yield self.transport_client.bulk_get_publicised_groups(
|
||||
destination, dest_user_ids,
|
||||
)
|
||||
results.update(r)
|
||||
|
||||
for uid in locals:
|
||||
results[uid] = yield self.store.get_publicised_groups_for_user(
|
||||
uid
|
||||
)
|
||||
|
||||
defer.returnValue({"users": results})
|
||||
|
|
|
@ -584,6 +584,59 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
|
|||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class PublicisedGroupsForUserServlet(RestServlet):
|
||||
"""Get the list of groups a user is advertising
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/publicised_groups/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PublicisedGroupsForUserServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
result = yield self.groups_handler.get_publicised_groups_for_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class PublicisedGroupsForUsersServlet(RestServlet):
|
||||
"""Get the list of groups a user is advertising
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/publicised_groups$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PublicisedGroupsForUsersServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
user_ids = content["user_ids"]
|
||||
|
||||
result = yield self.groups_handler.bulk_get_publicised_groups(
|
||||
user_ids
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupsForUserServlet(RestServlet):
|
||||
"""Get all groups the logged in user is joined to
|
||||
"""
|
||||
|
@ -627,3 +680,4 @@ def register_servlets(hs, http_server):
|
|||
GroupRolesServlet(hs).register(http_server)
|
||||
GroupSelfUpdatePublicityServlet(hs).register(http_server)
|
||||
GroupSummaryUsersRoleServlet(hs).register(http_server)
|
||||
PublicisedGroupsForUserServlet(hs).register(http_server)
|
||||
|
|
|
@ -835,6 +835,20 @@ class GroupServerStore(SQLBaseStore):
|
|||
desc="add_room_to_group",
|
||||
)
|
||||
|
||||
def get_publicised_groups_for_user(self, user_id):
|
||||
"""Get all groups a user is publicising
|
||||
"""
|
||||
return self._simple_select_onecol(
|
||||
table="local_group_membership",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"membership": "join",
|
||||
"is_publicised": True,
|
||||
},
|
||||
retcol="group_id",
|
||||
desc="get_publicised_groups_for_user",
|
||||
)
|
||||
|
||||
def update_group_publicity(self, group_id, user_id, publicise):
|
||||
"""Update whether the user is publicising their membership of the group
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue