synapse-old/synapse/config/oembed.py

174 lines
6.1 KiB
Python

# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from typing import Any, Dict, Iterable, List, Optional, Pattern
from urllib import parse as urlparse
import attr
import pkg_resources
from synapse.types import JsonDict, StrSequence
from ._base import Config, ConfigError
from ._util import validate_config
@attr.s(auto_attribs=True, frozen=True, slots=True)
class OEmbedEndpointConfig:
    """Immutable description of one oEmbed API endpoint."""

    # URL of the oEmbed API endpoint that should be queried.
    api_endpoint: str
    # Compiled regular expressions; the endpoint applies to URLs they match.
    url_patterns: List[Pattern]
    # Response formats listed for the endpoint, or None if none were listed.
    formats: Optional[List[str]]
class OembedConfig(Config):
    """oEmbed Configuration

    Reads the ``oembed`` section of the configuration and converts the
    provider definitions (the packaged ``res/providers.json`` plus any files
    listed under ``additional_providers``) into ``OEmbedEndpointConfig``
    objects, stored on ``self.oembed_patterns``.
    """

    section = "oembed"

    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
        """Populate ``self.oembed_patterns`` from the ``oembed`` section.

        Args:
            config: The parsed configuration; a missing or empty ``oembed``
                section is treated as an empty dict.
            **kwargs: Accepted but not used by this method.
        """
        oembed_config: Dict[str, Any] = config.get("oembed") or {}

        # A list of patterns which will be used.
        self.oembed_patterns: List[OEmbedEndpointConfig] = list(
            self._parse_and_validate_providers(oembed_config)
        )

    def _parse_and_validate_providers(
        self, oembed_config: dict
    ) -> Iterable[OEmbedEndpointConfig]:
        """Extract and parse the oEmbed providers from the provider JSON files.

        Args:
            oembed_config: The ``oembed`` section of the configuration.

        Returns:
            A generator which yields the OEmbedEndpointConfig objects, first
            for the packaged providers (unless disabled) and then for each
            file in ``additional_providers``, in order.
        """
        # Whether to use the packaged providers.json file.
        if not oembed_config.get("disable_default_providers"):
            with pkg_resources.resource_stream("synapse", "res/providers.json") as s:
                providers = json.load(s)

            yield from self._parse_and_validate_provider(
                providers, config_path=("oembed",)
            )

        # The JSON files which includes additional provider information.
        for i, file in enumerate(oembed_config.get("additional_providers") or []):
            # TODO Error checking.
            # Read as bytes (like the packaged stream above) so that json
            # detects the encoding itself instead of relying on the locale.
            with open(file, "rb") as f:
                providers = json.load(f)

            yield from self._parse_and_validate_provider(
                providers,
                config_path=(
                    "oembed",
                    "additional_providers",
                    f"<item {i}>",
                ),
            )

    def _parse_and_validate_provider(
        self, providers: List[JsonDict], config_path: StrSequence
    ) -> Iterable[OEmbedEndpointConfig]:
        """Validate a list of providers and yield a config per API endpoint.

        Args:
            providers: The decoded contents of a providers JSON file.
            config_path: Where in the configuration these providers came
                from; passed to validation and used in error reporting.

        Returns:
            A generator which yields one OEmbedEndpointConfig for every
            endpoint of every provider.

        Raises:
            ConfigError: If an endpoint URL or one of its scheme globs is not
                an HTTP(S) URL.
        """
        # Ensure it is the proper form.
        validate_config(
            _OEMBED_PROVIDER_SCHEMA,
            providers,
            config_path=config_path,
        )

        # Parse it and yield each result.
        for provider in providers:
            # Each provider might have multiple API endpoints, each which
            # might have multiple patterns to match.
            for endpoint in provider["endpoints"]:
                api_endpoint = endpoint["url"]

                # The API endpoint must be an HTTP(S) URL.
                results = urlparse.urlparse(api_endpoint)
                if results.scheme not in {"http", "https"}:
                    raise ConfigError(
                        f"Unsupported oEmbed scheme ({results.scheme}) for endpoint {api_endpoint}",
                        config_path,
                    )

                patterns = [
                    self._glob_to_pattern(glob, config_path)
                    for glob in endpoint["schemes"]
                ]
                yield OEmbedEndpointConfig(
                    api_endpoint, patterns, endpoint.get("formats")
                )

    def _glob_to_pattern(self, glob: str, config_path: StrSequence) -> Pattern:
        """
        Convert the glob into a sane regular expression to match against. The
        rules followed will be slightly different for the domain portion vs.
        the rest.

        1. The scheme must be one of HTTP / HTTPS (and have no globs).
        2. The domain can have globs, but we limit it to characters that can
           reasonably be a domain part.
           TODO: This does not attempt to handle Unicode domain names.
           TODO: The domain should not allow wildcard TLDs.
        3. Other parts allow a glob to be any one, or more, characters.

        Args:
            glob: The URL glob to convert.
            config_path: Where in the configuration the glob came from; used
                in error reporting.

        Returns:
            The compiled regular expression.

        Raises:
            ConfigError: If the scheme of the glob is not HTTP(S).
        """
        results = urlparse.urlparse(glob)

        # The scheme must be HTTP(S) (and cannot contain wildcards).
        if results.scheme not in {"http", "https"}:
            raise ConfigError(
                f"Unsupported oEmbed scheme ({results.scheme}) for pattern: {glob}",
                config_path,
            )

        netloc_pattern = re.escape(results.netloc).replace(
            "\\*", "[a-zA-Z0-9_-]+"
        )
        path_pattern, params_pattern, query_pattern, fragment_pattern = (
            re.escape(part).replace("\\*", ".+") for part in results[2:]
        )

        # Reassemble everything up to (but excluding) the query string.
        pattern = urlparse.urlunparse(
            [results.scheme, netloc_pattern, path_pattern, params_pattern, "", ""]
        )

        # The query separator has to be added by hand: urlunparse would insert
        # a bare "?", which is a regex quantifier (it would make the previous
        # character optional) rather than a literal question mark.
        if query_pattern:
            pattern += "\\?" + query_pattern
        if fragment_pattern:
            pattern += "#" + fragment_pattern

        return re.compile(pattern)
# JSON schema for a single entry of a provider's "endpoints" array.
_OEMBED_ENDPOINT_SCHEMA = {
    "type": "object",
    "properties": {
        "schemes": {"type": "array", "items": {"type": "string"}},
        "url": {"type": "string"},
        "formats": {"type": "array", "items": {"type": "string"}},
        "discovery": {"type": "boolean"},
    },
    "required": ["schemes", "url"],
}

# JSON schema for a providers file: an array of provider objects, each of
# which names the provider and lists its endpoints.
_OEMBED_PROVIDER_SCHEMA = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "provider_name": {"type": "string"},
            "provider_url": {"type": "string"},
            "endpoints": {"type": "array", "items": _OEMBED_ENDPOINT_SCHEMA},
        },
        "required": ["provider_name", "provider_url", "endpoints"],
    },
}