Merge pull request #1026 from matrix-org/paul/thirdpartylookup
3rd party entity lookup
This commit is contained in:
commit
5674ea3e6c
|
@ -81,13 +81,17 @@ class ApplicationService(object):
|
|||
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
|
||||
|
||||
def __init__(self, token, url=None, namespaces=None, hs_token=None,
|
||||
sender=None, id=None):
|
||||
sender=None, id=None, protocols=None):
|
||||
self.token = token
|
||||
self.url = url
|
||||
self.hs_token = hs_token
|
||||
self.sender = sender
|
||||
self.namespaces = self._check_namespaces(namespaces)
|
||||
self.id = id
|
||||
if protocols:
|
||||
self.protocols = set(protocols)
|
||||
else:
|
||||
self.protocols = set()
|
||||
|
||||
def _check_namespaces(self, namespaces):
|
||||
# Sanity check that it is of the form:
|
||||
|
@ -219,6 +223,9 @@ class ApplicationService(object):
|
|||
or user_id == self.sender
|
||||
)
|
||||
|
||||
def is_interested_in_protocol(self, protocol):
|
||||
return protocol in self.protocols
|
||||
|
||||
def is_exclusive_alias(self, alias):
|
||||
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from twisted.internet import defer
|
|||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
|
@ -24,6 +25,28 @@ import urllib
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_valid_3pe_result(r, field):
|
||||
if not isinstance(r, dict):
|
||||
return False
|
||||
|
||||
for k in (field, "protocol"):
|
||||
if k not in r:
|
||||
return False
|
||||
if not isinstance(r[k], str):
|
||||
return False
|
||||
|
||||
if "fields" not in r:
|
||||
return False
|
||||
fields = r["fields"]
|
||||
if not isinstance(fields, dict):
|
||||
return False
|
||||
for k in fields.keys():
|
||||
if not isinstance(fields[k], str):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class ApplicationServiceApi(SimpleHttpClient):
|
||||
"""This class manages HS -> AS communications, including querying and
|
||||
pushing.
|
||||
|
@ -71,6 +94,43 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
logger.warning("query_alias to %s threw exception %s", uri, ex)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pe(self, service, kind, protocol, fields):
|
||||
if kind == ThirdPartyEntityKind.USER:
|
||||
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
||||
required_field = "userid"
|
||||
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
||||
required_field = "alias"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||
)
|
||||
|
||||
try:
|
||||
response = yield self.get_json(uri, fields)
|
||||
if not isinstance(response, list):
|
||||
logger.warning(
|
||||
"query_3pe to %s returned an invalid response %r",
|
||||
uri, response
|
||||
)
|
||||
defer.returnValue([])
|
||||
|
||||
ret = []
|
||||
for r in response:
|
||||
if _is_valid_3pe_result(r, field=required_field):
|
||||
ret.append(r)
|
||||
else:
|
||||
logger.warning(
|
||||
"query_3pe to %s returned an invalid result %r",
|
||||
uri, r
|
||||
)
|
||||
|
||||
defer.returnValue(ret)
|
||||
except Exception as ex:
|
||||
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
||||
defer.returnValue([])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_bulk(self, service, events, txn_id=None):
|
||||
events = self._serialize(events)
|
||||
|
|
|
@ -123,6 +123,15 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||
raise ValueError(
|
||||
"Missing/bad type 'exclusive' key in %s", regex_obj
|
||||
)
|
||||
# protocols check
|
||||
protocols = as_info.get("protocols")
|
||||
if protocols:
|
||||
# Because strings are lists in python
|
||||
if isinstance(protocols, str) or not isinstance(protocols, list):
|
||||
raise KeyError("Optional 'protocols' must be a list if present.")
|
||||
for p in protocols:
|
||||
if not isinstance(p, str):
|
||||
raise KeyError("Bad value for 'protocols' item")
|
||||
return ApplicationService(
|
||||
token=as_info["as_token"],
|
||||
url=as_info["url"],
|
||||
|
@ -130,4 +139,5 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||
hs_token=as_info["hs_token"],
|
||||
sender=user_id,
|
||||
id=as_info["id"],
|
||||
protocols=protocols,
|
||||
)
|
||||
|
|
|
@ -159,6 +159,22 @@ class ApplicationServicesHandler(object):
|
|||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pe(self, kind, protocol, fields):
|
||||
services = yield self._get_services_for_3pn(protocol)
|
||||
|
||||
results = yield defer.DeferredList([
|
||||
self.appservice_api.query_3pe(service, kind, protocol, fields)
|
||||
for service in services
|
||||
], consumeErrors=True)
|
||||
|
||||
ret = []
|
||||
for (success, result) in results:
|
||||
if success:
|
||||
ret.extend(result)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_services_for_event(self, event):
|
||||
"""Retrieve a list of application services interested in this event.
|
||||
|
@ -187,6 +203,14 @@ class ApplicationServicesHandler(object):
|
|||
]
|
||||
defer.returnValue(interested_list)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_services_for_3pn(self, protocol):
|
||||
services = yield self.store.get_app_services()
|
||||
interested_list = [
|
||||
s for s in services if s.is_interested_in_protocol(protocol)
|
||||
]
|
||||
defer.returnValue(interested_list)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_unknown_user(self, user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
|
|
|
@ -47,6 +47,7 @@ from synapse.rest.client.v2_alpha import (
|
|||
report_event,
|
||||
openid,
|
||||
devices,
|
||||
thirdparty,
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
|
@ -92,3 +93,4 @@ class ClientRestResource(JsonResource):
|
|||
report_event.register_servlets(hs, client_resource)
|
||||
openid.register_servlets(hs, client_resource)
|
||||
devices.register_servlets(hs, client_resource)
|
||||
thirdparty.register_servlets(hs, client_resource)
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThirdPartyUserServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
|
||||
releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThirdPartyUserServlet, self).__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, protocol):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
fields = request.args
|
||||
del fields["access_token"]
|
||||
|
||||
results = yield self.appservice_handler.query_3pe(
|
||||
ThirdPartyEntityKind.USER, protocol, fields
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
||||
|
||||
class ThirdPartyLocationServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
|
||||
releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThirdPartyLocationServlet, self).__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, protocol):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
fields = request.args
|
||||
del fields["access_token"]
|
||||
|
||||
results = yield self.appservice_handler.query_3pe(
|
||||
ThirdPartyEntityKind.LOCATION, protocol, fields
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ThirdPartyUserServlet(hs).register(http_server)
|
||||
ThirdPartyLocationServlet(hs).register(http_server)
|
|
@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||
return "t%d-%d" % (self.topological, self.stream)
|
||||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
|
||||
# Some arbitrary constants used for internal API enumerations. Don't rely on
|
||||
# exact values; always pass or compare symbolically
|
||||
class ThirdPartyEntityKind(object):
|
||||
USER = 'user'
|
||||
LOCATION = 'location'
|
||||
|
|
Loading…
Reference in New Issue