MSC4108 implementation (#17056)

Co-authored-by: Hugh Nimmo-Smith <hughns@element.io>
Co-authored-by: Hugh Nimmo-Smith <hughns@users.noreply.github.com>
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
Quentin Gliech 2024-04-25 14:50:12 +02:00 committed by GitHub
parent 646cb6ff24
commit 2e92b718d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1120 additions and 9 deletions

164
Cargo.lock generated
View File

@ -59,6 +59,12 @@ dependencies = [
"generic-array", "generic-array",
] ]
[[package]]
name = "bumpalo"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.6.0" version = "1.6.0"
@ -92,9 +98,9 @@ dependencies = [
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.5" version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adfbc57365a37acbd2ebf2b64d7e69bb766e2fea813521ed536f5d0520dcf86c" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [ dependencies = [
"block-buffer", "block-buffer",
"crypto-common", "crypto-common",
@ -117,6 +123,19 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "getrandom"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi",
"wasm-bindgen",
]
[[package]] [[package]]
name = "headers" name = "headers"
version = "0.4.0" version = "0.4.0"
@ -182,6 +201,15 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc" checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc"
[[package]]
name = "js-sys"
version = "0.3.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d"
dependencies = [
"wasm-bindgen",
]
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.4.0"
@ -266,6 +294,12 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.76" version = "1.0.76"
@ -369,6 +403,36 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.2.16" version = "0.2.16"
@ -461,6 +525,17 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "sha2"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.10.0" version = "1.10.0"
@ -489,6 +564,7 @@ name = "synapse"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"base64",
"blake2", "blake2",
"bytes", "bytes",
"headers", "headers",
@ -496,12 +572,15 @@ dependencies = [
"http", "http",
"lazy_static", "lazy_static",
"log", "log",
"mime",
"pyo3", "pyo3",
"pyo3-log", "pyo3-log",
"pythonize", "pythonize",
"regex", "regex",
"serde", "serde",
"serde_json", "serde_json",
"sha2",
"ulid",
] ]
[[package]] [[package]]
@ -516,6 +595,17 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "ulid"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34778c17965aa2a08913b57e1f34db9b4a63f5de31768b55bf20d2795f921259"
dependencies = [
"getrandom",
"rand",
"web-time",
]
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.5" version = "1.0.5"
@ -534,6 +624,76 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasm-bindgen"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96"
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.36.1" version = "0.36.1"

View File

@ -0,0 +1 @@
Implement the rendezvous mechanism described by MSC4108.

View File

@ -23,11 +23,13 @@ name = "synapse.synapse_rust"
[dependencies] [dependencies]
anyhow = "1.0.63" anyhow = "1.0.63"
base64 = "0.21.7"
bytes = "1.6.0" bytes = "1.6.0"
headers = "0.4.0" headers = "0.4.0"
http = "1.1.0" http = "1.1.0"
lazy_static = "1.4.0" lazy_static = "1.4.0"
log = "0.4.17" log = "0.4.17"
mime = "0.3.17"
pyo3 = { version = "0.20.0", features = [ pyo3 = { version = "0.20.0", features = [
"macros", "macros",
"anyhow", "anyhow",
@ -37,8 +39,10 @@ pyo3 = { version = "0.20.0", features = [
pyo3-log = "0.9.0" pyo3-log = "0.9.0"
pythonize = "0.20.0" pythonize = "0.20.0"
regex = "1.6.0" regex = "1.6.0"
sha2 = "0.10.8"
serde = { version = "1.0.144", features = ["derive"] } serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85" serde_json = "1.0.85"
ulid = "1.1.2"
[features] [features]
extension-module = ["pyo3/extension-module"] extension-module = ["pyo3/extension-module"]

View File

@ -7,6 +7,7 @@ pub mod errors;
pub mod events; pub mod events;
pub mod http; pub mod http;
pub mod push; pub mod push;
pub mod rendezvous;
lazy_static! { lazy_static! {
static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init(); static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init();
@ -45,6 +46,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
acl::register_module(py, m)?; acl::register_module(py, m)?;
push::register_module(py, m)?; push::register_module(py, m)?;
events::register_module(py, m)?; events::register_module(py, m)?;
rendezvous::register_module(py, m)?;
Ok(()) Ok(())
} }

315
rust/src/rendezvous/mod.rs Normal file
View File

@ -0,0 +1,315 @@
/*
* This file is licensed under the Affero General Public License (AGPL) version 3.
*
* Copyright (C) 2024 New Vector, Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* See the GNU Affero General Public License for more details:
* <https://www.gnu.org/licenses/agpl-3.0.html>.
*
*/
use std::{
collections::{BTreeMap, HashMap},
time::{Duration, SystemTime},
};
use bytes::Bytes;
use headers::{
AccessControlAllowOrigin, AccessControlExposeHeaders, CacheControl, ContentLength, ContentType,
HeaderMapExt, IfMatch, IfNoneMatch, Pragma,
};
use http::{header::ETAG, HeaderMap, Response, StatusCode, Uri};
use mime::Mime;
use pyo3::{
exceptions::PyValueError, pyclass, pymethods, types::PyModule, Py, PyAny, PyObject, PyResult,
Python, ToPyObject,
};
use ulid::Ulid;
use self::session::Session;
use crate::{
errors::{NotFoundError, SynapseError},
http::{http_request_from_twisted, http_response_to_twisted, HeaderMapPyExt},
};
mod session;
// n.b. Because OPTIONS requests are handled by the Python code, we don't need to set Access-Control-Allow-Headers.
fn prepare_headers(headers: &mut HeaderMap, session: &Session) {
headers.typed_insert(AccessControlAllowOrigin::ANY);
headers.typed_insert(AccessControlExposeHeaders::from_iter([ETAG]));
headers.typed_insert(Pragma::no_cache());
headers.typed_insert(CacheControl::new().with_no_store());
headers.typed_insert(session.etag());
headers.typed_insert(session.expires());
headers.typed_insert(session.last_modified());
}
#[pyclass]
struct RendezvousHandler {
base: Uri,
clock: PyObject,
sessions: BTreeMap<Ulid, Session>,
capacity: usize,
max_content_length: u64,
ttl: Duration,
}
impl RendezvousHandler {
/// Check the input headers of a request which sets data for a session, and return the content type.
fn check_input_headers(&self, headers: &HeaderMap) -> PyResult<Mime> {
let ContentLength(content_length) = headers.typed_get_required()?;
if content_length > self.max_content_length {
return Err(SynapseError::new(
StatusCode::PAYLOAD_TOO_LARGE,
"Payload too large".to_owned(),
"M_TOO_LARGE",
None,
None,
));
}
let content_type: ContentType = headers.typed_get_required()?;
// Content-Type must be text/plain
if content_type != ContentType::text() {
return Err(SynapseError::new(
StatusCode::BAD_REQUEST,
"Content-Type must be text/plain".to_owned(),
"M_INVALID_PARAM",
None,
None,
));
}
Ok(content_type.into())
}
/// Evict expired sessions and remove the oldest sessions until we're under the capacity.
fn evict(&mut self, now: SystemTime) {
// First remove all the entries which expired
self.sessions.retain(|_, session| !session.expired(now));
// Then we remove the oldest entires until we're under the limit
while self.sessions.len() > self.capacity {
self.sessions.pop_first();
}
}
}
#[pymethods]
impl RendezvousHandler {
#[new]
#[pyo3(signature = (homeserver, /, capacity=100, max_content_length=4*1024, eviction_interval=60*1000, ttl=60*1000))]
fn new(
py: Python<'_>,
homeserver: &PyAny,
capacity: usize,
max_content_length: u64,
eviction_interval: u64,
ttl: u64,
) -> PyResult<Py<Self>> {
let base: String = homeserver
.getattr("config")?
.getattr("server")?
.getattr("public_baseurl")?
.extract()?;
let base = Uri::try_from(format!("{base}_synapse/client/rendezvous"))
.map_err(|_| PyValueError::new_err("Invalid base URI"))?;
let clock = homeserver.call_method0("get_clock")?.to_object(py);
// Construct a Python object so that we can get a reference to the
// evict method and schedule it to run.
let self_ = Py::new(
py,
Self {
base,
clock,
sessions: BTreeMap::new(),
capacity,
max_content_length,
ttl: Duration::from_millis(ttl),
},
)?;
let evict = self_.getattr(py, "_evict")?;
homeserver.call_method0("get_clock")?.call_method(
"looping_call",
(evict, eviction_interval),
None,
)?;
Ok(self_)
}
fn _evict(&mut self, py: Python<'_>) -> PyResult<()> {
let clock = self.clock.as_ref(py);
let now: u64 = clock.call_method0("time_msec")?.extract()?;
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
self.evict(now);
Ok(())
}
fn handle_post(&mut self, py: Python<'_>, twisted_request: &PyAny) -> PyResult<()> {
let request = http_request_from_twisted(twisted_request)?;
let content_type = self.check_input_headers(request.headers())?;
let clock = self.clock.as_ref(py);
let now: u64 = clock.call_method0("time_msec")?.extract()?;
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
// We trigger an immediate eviction if we're at 2x the capacity
if self.sessions.len() >= self.capacity * 2 {
self.evict(now);
}
// Generate a new ULID for the session from the current time.
let id = Ulid::from_datetime(now);
let uri = format!("{base}/{id}", base = self.base);
let body = request.into_body();
let session = Session::new(body, content_type, now, self.ttl);
let response = serde_json::json!({
"url": uri,
})
.to_string();
let mut response = Response::new(response.as_bytes());
*response.status_mut() = StatusCode::CREATED;
response.headers_mut().typed_insert(ContentType::json());
prepare_headers(response.headers_mut(), &session);
http_response_to_twisted(twisted_request, response)?;
self.sessions.insert(id, session);
Ok(())
}
fn handle_get(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> {
let request = http_request_from_twisted(twisted_request)?;
let if_none_match: Option<IfNoneMatch> = request.headers().typed_get_optional()?;
let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
let session = self
.sessions
.get(&id)
.filter(|s| !s.expired(now))
.ok_or_else(NotFoundError::new)?;
if let Some(if_none_match) = if_none_match {
if !if_none_match.precondition_passes(&session.etag()) {
let mut response = Response::new(Bytes::new());
*response.status_mut() = StatusCode::NOT_MODIFIED;
prepare_headers(response.headers_mut(), session);
http_response_to_twisted(twisted_request, response)?;
return Ok(());
}
}
let mut response = Response::new(session.data());
*response.status_mut() = StatusCode::OK;
let headers = response.headers_mut();
prepare_headers(headers, session);
headers.typed_insert(session.content_type());
headers.typed_insert(session.content_length());
http_response_to_twisted(twisted_request, response)?;
Ok(())
}
fn handle_put(&mut self, py: Python<'_>, twisted_request: &PyAny, id: &str) -> PyResult<()> {
let request = http_request_from_twisted(twisted_request)?;
let content_type = self.check_input_headers(request.headers())?;
let if_match: IfMatch = request.headers().typed_get_required()?;
let data = request.into_body();
let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
let session = self
.sessions
.get_mut(&id)
.filter(|s| !s.expired(now))
.ok_or_else(NotFoundError::new)?;
if !if_match.precondition_passes(&session.etag()) {
let mut headers = HeaderMap::new();
prepare_headers(&mut headers, session);
let mut additional_fields = HashMap::with_capacity(1);
additional_fields.insert(
String::from("org.matrix.msc4108.errcode"),
String::from("M_CONCURRENT_WRITE"),
);
return Err(SynapseError::new(
StatusCode::PRECONDITION_FAILED,
"ETag does not match".to_owned(),
"M_UNKNOWN", // Would be M_CONCURRENT_WRITE
Some(additional_fields),
Some(headers),
));
}
session.update(data, content_type, now);
let mut response = Response::new(Bytes::new());
*response.status_mut() = StatusCode::ACCEPTED;
prepare_headers(response.headers_mut(), session);
http_response_to_twisted(twisted_request, response)?;
Ok(())
}
fn handle_delete(&mut self, twisted_request: &PyAny, id: &str) -> PyResult<()> {
let _request = http_request_from_twisted(twisted_request)?;
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
let _session = self.sessions.remove(&id).ok_or_else(NotFoundError::new)?;
let mut response = Response::new(Bytes::new());
*response.status_mut() = StatusCode::NO_CONTENT;
response
.headers_mut()
.typed_insert(AccessControlAllowOrigin::ANY);
http_response_to_twisted(twisted_request, response)?;
Ok(())
}
}
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let child_module = PyModule::new(py, "rendezvous")?;
child_module.add_class::<RendezvousHandler>()?;
m.add_submodule(child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import rendezvous` work.
py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.rendezvous", child_module)?;
Ok(())
}

View File

@ -0,0 +1,91 @@
/*
* This file is licensed under the Affero General Public License (AGPL) version 3.
*
* Copyright (C) 2024 New Vector, Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* See the GNU Affero General Public License for more details:
* <https://www.gnu.org/licenses/agpl-3.0.html>.
*/
use std::time::{Duration, SystemTime};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use bytes::Bytes;
use headers::{ContentLength, ContentType, ETag, Expires, LastModified};
use mime::Mime;
use sha2::{Digest, Sha256};
/// A single session, containing data, metadata, and expiry information.
pub struct Session {
hash: [u8; 32],
data: Bytes,
content_type: Mime,
last_modified: SystemTime,
expires: SystemTime,
}
impl Session {
/// Create a new session with the given data, content type, and time-to-live.
pub fn new(data: Bytes, content_type: Mime, now: SystemTime, ttl: Duration) -> Self {
let hash = Sha256::digest(&data).into();
Self {
hash,
data,
content_type,
expires: now + ttl,
last_modified: now,
}
}
/// Returns true if the session has expired at the given time.
pub fn expired(&self, now: SystemTime) -> bool {
self.expires <= now
}
/// Update the session with new data, content type, and last modified time.
pub fn update(&mut self, data: Bytes, content_type: Mime, now: SystemTime) {
self.hash = Sha256::digest(&data).into();
self.data = data;
self.content_type = content_type;
self.last_modified = now;
}
/// Returns the Content-Type header of the session.
pub fn content_type(&self) -> ContentType {
self.content_type.clone().into()
}
/// Returns the Content-Length header of the session.
pub fn content_length(&self) -> ContentLength {
ContentLength(self.data.len() as _)
}
/// Returns the ETag header of the session.
pub fn etag(&self) -> ETag {
let encoded = URL_SAFE_NO_PAD.encode(self.hash);
// SAFETY: Base64 encoding is URL-safe, so ETag-safe
format!("\"{encoded}\"")
.parse()
.expect("base64-encoded hash should be URL-safe")
}
/// Returns the Last-Modified header of the session.
pub fn last_modified(&self) -> LastModified {
self.last_modified.into()
}
/// Returns the Expires header of the session.
pub fn expires(&self) -> Expires {
self.expires.into()
}
/// Returns the current data stored in the session.
pub fn data(&self) -> Bytes {
self.data.clone()
}
}

View File

@ -413,12 +413,22 @@ class ExperimentalConfig(Config):
) )
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
self.msc4108_enabled = experimental.get("msc4108_enabled", False)
self.msc4108_delegation_endpoint: Optional[str] = experimental.get( self.msc4108_delegation_endpoint: Optional[str] = experimental.get(
"msc4108_delegation_endpoint", None "msc4108_delegation_endpoint", None
) )
if self.msc4108_delegation_endpoint is not None and not self.msc3861.enabled: if (
self.msc4108_enabled or self.msc4108_delegation_endpoint is not None
) and not self.msc3861.enabled:
raise ConfigError( raise ConfigError(
"MSC4108 requires MSC3861 to be enabled", "MSC4108 requires MSC3861 to be enabled",
("experimental", "msc4108_delegation_endpoint"), ("experimental", "msc4108_delegation_endpoint"),
) )
if self.msc4108_delegation_endpoint is not None and self.msc4108_enabled:
raise ConfigError(
"You cannot have MSC4108 both enabled and delegated at the same time",
("experimental", "msc4108_delegation_endpoint"),
)

View File

@ -909,8 +909,9 @@ def set_cors_headers(request: "SynapseRequest") -> None:
request.setHeader( request.setHeader(
b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
) )
if request.path is not None and request.path.startswith( if request.path is not None and (
b"/_matrix/client/unstable/org.matrix.msc4108/rendezvous" request.path == b"/_matrix/client/unstable/org.matrix.msc4108/rendezvous"
or request.path.startswith(b"/_synapse/client/rendezvous")
): ):
request.setHeader( request.setHeader(
b"Access-Control-Allow-Headers", b"Access-Control-Allow-Headers",

View File

@ -97,9 +97,25 @@ class MSC4108DelegationRendezvousServlet(RestServlet):
) )
class MSC4108RendezvousServlet(RestServlet):
PATTERNS = client_patterns(
"/org.matrix.msc4108/rendezvous$", releases=[], v1=False, unstable=True
)
def __init__(self, hs: "HomeServer") -> None:
super().__init__()
self._handler = hs.get_rendezvous_handler()
def on_POST(self, request: SynapseRequest) -> None:
self._handler.handle_post(request)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc3886_endpoint is not None: if hs.config.experimental.msc3886_endpoint is not None:
MSC3886RendezvousServlet(hs).register(http_server) MSC3886RendezvousServlet(hs).register(http_server)
if hs.config.experimental.msc4108_enabled:
MSC4108RendezvousServlet(hs).register(http_server)
if hs.config.experimental.msc4108_delegation_endpoint is not None: if hs.config.experimental.msc4108_delegation_endpoint is not None:
MSC4108DelegationRendezvousServlet(hs).register(http_server) MSC4108DelegationRendezvousServlet(hs).register(http_server)

View File

@ -141,8 +141,13 @@ class VersionsRestServlet(RestServlet):
# Allows clients to handle push for encrypted events. # Allows clients to handle push for encrypted events.
"org.matrix.msc4028": self.config.experimental.msc4028_push_encrypted_events, "org.matrix.msc4028": self.config.experimental.msc4028_push_encrypted_events,
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code # MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
"org.matrix.msc4108": self.config.experimental.msc4108_delegation_endpoint "org.matrix.msc4108": (
is not None, self.config.experimental.msc4108_enabled
or (
self.config.experimental.msc4108_delegation_endpoint
is not None
)
),
}, },
}, },
) )

View File

@ -26,6 +26,7 @@ from twisted.web.resource import Resource
from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_idp import PickIdpResource
from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource
from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.rest.synapse.client.sso_register import SsoRegisterResource
from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
@ -76,6 +77,9 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
# To be removed in Synapse v1.32.0. # To be removed in Synapse v1.32.0.
resources["/_matrix/saml2"] = res resources["/_matrix/saml2"] = res
if hs.config.experimental.msc4108_enabled:
resources["/_synapse/client/rendezvous"] = MSC4108RendezvousSessionResource(hs)
return resources return resources

View File

@ -0,0 +1,58 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
#
import logging
from typing import TYPE_CHECKING, List
from synapse.api.errors import UnrecognizedRequestError
from synapse.http.server import DirectServeJsonResource
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class MSC4108RendezvousSessionResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs: "HomeServer") -> None:
super().__init__()
self._handler = hs.get_rendezvous_handler()
async def _async_render_GET(self, request: SynapseRequest) -> None:
postpath: List[bytes] = request.postpath # type: ignore
if len(postpath) != 1:
raise UnrecognizedRequestError()
session_id = postpath[0].decode("ascii")
self._handler.handle_get(request, session_id)
def _async_render_PUT(self, request: SynapseRequest) -> None:
postpath: List[bytes] = request.postpath # type: ignore
if len(postpath) != 1:
raise UnrecognizedRequestError()
session_id = postpath[0].decode("ascii")
self._handler.handle_put(request, session_id)
def _async_render_DELETE(self, request: SynapseRequest) -> None:
postpath: List[bytes] = request.postpath # type: ignore
if len(postpath) != 1:
raise UnrecognizedRequestError()
session_id = postpath[0].decode("ascii")
self._handler.handle_delete(request, session_id)

View File

@ -143,6 +143,7 @@ from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases from synapse.storage import Databases
from synapse.storage.controllers import StorageControllers from synapse.storage.controllers import StorageControllers
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.synapse_rust.rendezvous import RendezvousHandler
from synapse.types import DomainSpecificString, ISynapseReactor from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock from synapse.util import Clock
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
@ -859,6 +860,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_room_forgetter_handler(self) -> RoomForgetterHandler: def get_room_forgetter_handler(self) -> RoomForgetterHandler:
return RoomForgetterHandler(self) return RoomForgetterHandler(self)
@cache_in_self
def get_rendezvous_handler(self) -> RendezvousHandler:
return RendezvousHandler(self)
@cache_in_self @cache_in_self
def get_outbound_redis_connection(self) -> "ConnectionHandler": def get_outbound_redis_connection(self) -> "ConnectionHandler":
""" """

View File

@ -0,0 +1,30 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
from twisted.web.iweb import IRequest
from synapse.server import HomeServer
class RendezvousHandler:
def __init__(
self,
homeserver: HomeServer,
/,
capacity: int = 100,
max_content_length: int = 4 * 1024, # MSC4108 specifies 4KB
eviction_interval: int = 60 * 1000,
ttl: int = 60 * 1000,
) -> None: ...
def handle_post(self, request: IRequest) -> None: ...
def handle_get(self, request: IRequest, session_id: str) -> None: ...
def handle_put(self, request: IRequest, session_id: str) -> None: ...
def handle_delete(self, request: IRequest, session_id: str) -> None: ...

View File

@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3. # This file is licensed under the Affero General Public License (AGPL) version 3.
# #
# Copyright 2022 The Matrix.org Foundation C.I.C. # Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd # Copyright (C) 2023-2024 New Vector, Ltd
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as # it under the terms of the GNU Affero General Public License as
@ -19,9 +19,14 @@
# #
# #
from typing import Dict
from urllib.parse import urlparse
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.rest.client import rendezvous from synapse.rest.client import rendezvous
from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -42,6 +47,12 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
return self.hs return self.hs
def create_resource_dict(self) -> Dict[str, Resource]:
return {
**super().create_resource_dict(),
"/_synapse/client/rendezvous": MSC4108RendezvousSessionResource(self.hs),
}
def test_disabled(self) -> None: def test_disabled(self) -> None:
channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None) channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
self.assertEqual(channel.code, 404) self.assertEqual(channel.code, 404)
@ -75,3 +86,391 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None) channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
self.assertEqual(channel.code, 307) self.assertEqual(channel.code, 307)
self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"]) self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"])
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
@override_config(
{
"disable_registration": True,
"experimental_features": {
"msc4108_enabled": True,
"msc3861": {
"enabled": True,
"issuer": "https://issuer",
"client_id": "client_id",
"client_auth_method": "client_secret_post",
"client_secret": "client_secret",
"admin_token": "admin_token_value",
},
},
}
)
def test_msc4108(self) -> None:
"""
Test the MSC4108 rendezvous endpoint, including:
- Creating a session
- Getting the data back
- Updating the data
- Deleting the data
- ETag handling
"""
# We can post arbitrary data to the endpoint
channel = self.make_request(
"POST",
msc4108_endpoint,
"foo=bar",
content_type=b"text/plain",
access_token=None,
)
self.assertEqual(channel.code, 201)
self.assertSubstring("/_synapse/client/rendezvous/", channel.json_body["url"])
headers = dict(channel.headers.getAllRawHeaders())
self.assertIn(b"ETag", headers)
self.assertIn(b"Expires", headers)
self.assertEqual(headers[b"Content-Type"], [b"application/json"])
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
self.assertEqual(headers[b"Pragma"], [b"no-cache"])
self.assertIn("url", channel.json_body)
self.assertTrue(channel.json_body["url"].startswith("https://"))
url = urlparse(channel.json_body["url"])
session_endpoint = url.path
etag = headers[b"ETag"][0]
# We can get the data back
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 200)
headers = dict(channel.headers.getAllRawHeaders())
self.assertEqual(headers[b"ETag"], [etag])
self.assertIn(b"Expires", headers)
self.assertEqual(headers[b"Content-Type"], [b"text/plain"])
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
self.assertEqual(headers[b"Pragma"], [b"no-cache"])
self.assertEqual(channel.text_body, "foo=bar")
# We can make sure the data hasn't changed
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
custom_headers=[("If-None-Match", etag)],
)
self.assertEqual(channel.code, 304)
# We can update the data
channel = self.make_request(
"PUT",
session_endpoint,
"foo=baz",
content_type=b"text/plain",
access_token=None,
custom_headers=[("If-Match", etag)],
)
self.assertEqual(channel.code, 202)
headers = dict(channel.headers.getAllRawHeaders())
old_etag = etag
new_etag = headers[b"ETag"][0]
# If we try to update it again with the old etag, it should fail
channel = self.make_request(
"PUT",
session_endpoint,
"bar=baz",
content_type=b"text/plain",
access_token=None,
custom_headers=[("If-Match", old_etag)],
)
self.assertEqual(channel.code, 412)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN")
self.assertEqual(
channel.json_body["org.matrix.msc4108.errcode"], "M_CONCURRENT_WRITE"
)
# If we try to get with the old etag, we should get the updated data
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
custom_headers=[("If-None-Match", old_etag)],
)
self.assertEqual(channel.code, 200)
headers = dict(channel.headers.getAllRawHeaders())
self.assertEqual(headers[b"ETag"], [new_etag])
self.assertEqual(channel.text_body, "foo=baz")
# We can delete the data
channel = self.make_request(
"DELETE",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 204)
# If we try to get the data again, it should fail
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
@override_config(
{
"disable_registration": True,
"experimental_features": {
"msc4108_enabled": True,
"msc3861": {
"enabled": True,
"issuer": "https://issuer",
"client_id": "client_id",
"client_auth_method": "client_secret_post",
"client_secret": "client_secret",
"admin_token": "admin_token_value",
},
},
}
)
def test_msc4108_expiration(self) -> None:
"""
Test that entries are evicted after a TTL.
"""
# Start a new session
channel = self.make_request(
"POST",
msc4108_endpoint,
"foo=bar",
content_type=b"text/plain",
access_token=None,
)
self.assertEqual(channel.code, 201)
session_endpoint = urlparse(channel.json_body["url"]).path
# Sanity check that we can get the data back
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.text_body, "foo=bar")
# Advance the clock, TTL of entries is 1 minute
self.reactor.advance(60)
# Get the data back, it should be gone
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 404)
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
@override_config(
{
"disable_registration": True,
"experimental_features": {
"msc4108_enabled": True,
"msc3861": {
"enabled": True,
"issuer": "https://issuer",
"client_id": "client_id",
"client_auth_method": "client_secret_post",
"client_secret": "client_secret",
"admin_token": "admin_token_value",
},
},
}
)
def test_msc4108_capacity(self) -> None:
"""
Test that a capacity limit is enforced on the rendezvous sessions, as old
entries are evicted at an interval when the limit is reached.
"""
# Start a new session
channel = self.make_request(
"POST",
msc4108_endpoint,
"foo=bar",
content_type=b"text/plain",
access_token=None,
)
self.assertEqual(channel.code, 201)
session_endpoint = urlparse(channel.json_body["url"]).path
# Sanity check that we can get the data back
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.text_body, "foo=bar")
# Start a lot of new sessions
for _ in range(100):
channel = self.make_request(
"POST",
msc4108_endpoint,
"foo=bar",
content_type=b"text/plain",
access_token=None,
)
self.assertEqual(channel.code, 201)
# Get the data back, it should still be there, as the eviction hasn't run yet
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 200)
# Advance the clock, as it will trigger the eviction
self.reactor.advance(1)
# Get the data back, it should be gone
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
@override_config(
{
"disable_registration": True,
"experimental_features": {
"msc4108_enabled": True,
"msc3861": {
"enabled": True,
"issuer": "https://issuer",
"client_id": "client_id",
"client_auth_method": "client_secret_post",
"client_secret": "client_secret",
"admin_token": "admin_token_value",
},
},
}
)
def test_msc4108_hard_capacity(self) -> None:
"""
Test that a hard capacity limit is enforced on the rendezvous sessions, as old
entries are evicted immediately when the limit is reached.
"""
# Start a new session
channel = self.make_request(
"POST",
msc4108_endpoint,
"foo=bar",
content_type=b"text/plain",
access_token=None,
)
self.assertEqual(channel.code, 201)
session_endpoint = urlparse(channel.json_body["url"]).path
# We advance the clock to make sure that this entry is the "lowest" in the session list
self.reactor.advance(1)
# Sanity check that we can get the data back
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.text_body, "foo=bar")
# Start a lot of new sessions
for _ in range(200):
channel = self.make_request(
"POST",
msc4108_endpoint,
"foo=bar",
content_type=b"text/plain",
access_token=None,
)
self.assertEqual(channel.code, 201)
# Get the data back, it should already be gone as we hit the hard limit
channel = self.make_request(
"GET",
session_endpoint,
access_token=None,
)
self.assertEqual(channel.code, 404)
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
@override_config(
{
"disable_registration": True,
"experimental_features": {
"msc4108_enabled": True,
"msc3861": {
"enabled": True,
"issuer": "https://issuer",
"client_id": "client_id",
"client_auth_method": "client_secret_post",
"client_secret": "client_secret",
"admin_token": "admin_token_value",
},
},
}
)
def test_msc4108_content_type(self) -> None:
"""
Test that the content-type is restricted to text/plain.
"""
# We cannot post invalid content-type arbitrary data to the endpoint
channel = self.make_request(
"POST",
msc4108_endpoint,
"foo=bar",
content_is_form=True,
access_token=None,
)
self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
# Make a valid request
channel = self.make_request(
"POST",
msc4108_endpoint,
"foo=bar",
content_type=b"text/plain",
access_token=None,
)
self.assertEqual(channel.code, 201)
url = urlparse(channel.json_body["url"])
session_endpoint = url.path
headers = dict(channel.headers.getAllRawHeaders())
etag = headers[b"ETag"][0]
# We can't update the data with invalid content-type
channel = self.make_request(
"PUT",
session_endpoint,
"foo=baz",
content_is_form=True,
access_token=None,
custom_headers=[("If-Match", etag)],
)
self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")

View File

@ -351,6 +351,7 @@ def make_request(
request: Type[Request] = SynapseRequest, request: Type[Request] = SynapseRequest,
shorthand: bool = True, shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None, federation_auth_origin: Optional[bytes] = None,
content_type: Optional[bytes] = None,
content_is_form: bool = False, content_is_form: bool = False,
await_result: bool = True, await_result: bool = True,
custom_headers: Optional[Iterable[CustomHeaderType]] = None, custom_headers: Optional[Iterable[CustomHeaderType]] = None,
@ -373,6 +374,8 @@ def make_request(
with the usual REST API path, if it doesn't contain it. with the usual REST API path, if it doesn't contain it.
federation_auth_origin: if set to not-None, we will add a fake federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name. Authorization header pretenting to be the given server name.
content_type: The content-type to use for the request. If not set then will default to
application/json unless content_is_form is true.
content_is_form: Whether the content is URL encoded form data. Adds the content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header. 'Content-Type': 'application/x-www-form-urlencoded' header.
await_result: whether to wait for the request to complete rendering. If true, await_result: whether to wait for the request to complete rendering. If true,
@ -436,7 +439,9 @@ def make_request(
) )
if content: if content:
if content_is_form: if content_type is not None:
req.requestHeaders.addRawHeader(b"Content-Type", content_type)
elif content_is_form:
req.requestHeaders.addRawHeader( req.requestHeaders.addRawHeader(
b"Content-Type", b"application/x-www-form-urlencoded" b"Content-Type", b"application/x-www-form-urlencoded"
) )

View File

@ -523,6 +523,7 @@ class HomeserverTestCase(TestCase):
request: Type[Request] = SynapseRequest, request: Type[Request] = SynapseRequest,
shorthand: bool = True, shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None, federation_auth_origin: Optional[bytes] = None,
content_type: Optional[bytes] = None,
content_is_form: bool = False, content_is_form: bool = False,
await_result: bool = True, await_result: bool = True,
custom_headers: Optional[Iterable[CustomHeaderType]] = None, custom_headers: Optional[Iterable[CustomHeaderType]] = None,
@ -541,6 +542,9 @@ class HomeserverTestCase(TestCase):
with the usual REST API path, if it doesn't contain it. with the usual REST API path, if it doesn't contain it.
federation_auth_origin: if set to not-None, we will add a fake federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name. Authorization header pretenting to be the given server name.
content_type: The content-type to use for the request. If not set then will default to
application/json unless content_is_form is true.
content_is_form: Whether the content is URL encoded form data. Adds the content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header. 'Content-Type': 'application/x-www-form-urlencoded' header.
@ -566,6 +570,7 @@ class HomeserverTestCase(TestCase):
request, request,
shorthand, shorthand,
federation_auth_origin, federation_auth_origin,
content_type,
content_is_form, content_is_form,
await_result, await_result,
custom_headers, custom_headers,