Update black, and run auto formatting over the codebase (#9381)

- Update black version to the latest
 - Run black auto formatting over the codebase
    - Run autoformatting according to [`docs/code_style.md
`](80d6dc9783/docs/code_style.md)
 - Update `code_style.md` docs around installing black to use the correct version
This commit is contained in:
Eric Eastwood 2021-02-16 16:32:34 -06:00 committed by GitHub
parent 5636e597c3
commit 0a00b7ff14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
271 changed files with 2802 additions and 1713 deletions

1
changelog.d/9381.misc Normal file
View File

@ -0,0 +1 @@
Update the version of black used to 20.8b1.

View File

@ -23,8 +23,7 @@ from twisted.web.http_headers import Headers
class HttpClient: class HttpClient:
""" Interface for talking json over http """Interface for talking json over http"""
"""
def put_json(self, url, data): def put_json(self, url, data):
"""Sends the specifed json data using PUT """Sends the specifed json data using PUT
@ -87,8 +86,7 @@ class TwistedHttpClient(HttpClient):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
def _create_put_request(self, url, json_data, headers_dict={}): def _create_put_request(self, url, json_data, headers_dict={}):
""" Wrapper of _create_request to issue a PUT request """Wrapper of _create_request to issue a PUT request"""
"""
if "Content-Type" not in headers_dict: if "Content-Type" not in headers_dict:
raise defer.error(RuntimeError("Must include Content-Type header for PUTs")) raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
@ -98,8 +96,7 @@ class TwistedHttpClient(HttpClient):
) )
def _create_get_request(self, url, headers_dict={}): def _create_get_request(self, url, headers_dict={}):
""" Wrapper of _create_request to issue a GET request """Wrapper of _create_request to issue a GET request"""
"""
return self._create_request("GET", url, headers_dict=headers_dict) return self._create_request("GET", url, headers_dict=headers_dict)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -127,8 +124,7 @@ class TwistedHttpClient(HttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, method, url, producer=None, headers_dict={}): def _create_request(self, method, url, producer=None, headers_dict={}):
""" Creates and sends a request to the given url """Creates and sends a request to the given url"""
"""
headers_dict["User-Agent"] = ["Synapse Cmd Client"] headers_dict["User-Agent"] = ["Synapse Cmd Client"]
retries_left = 5 retries_left = 5
@ -185,8 +181,7 @@ class _RawProducer:
class _JsonProducer: class _JsonProducer:
""" Used by the twisted http client to create the HTTP body from json """Used by the twisted http client to create the HTTP body from json"""
"""
def __init__(self, jsn): def __init__(self, jsn):
self.data = jsn self.data = jsn

View File

@ -63,8 +63,7 @@ class CursesStdIO:
self.redraw() self.redraw()
def redraw(self): def redraw(self):
""" method for redisplaying lines """method for redisplaying lines based on internal list of lines"""
based on internal list of lines """
self.stdscr.clear() self.stdscr.clear()
self.paintStatus(self.statusText) self.paintStatus(self.statusText)

View File

@ -68,8 +68,7 @@ class InputOutput:
self.server = server self.server = server
def on_line(self, line): def on_line(self, line):
""" This is where we process commands. """This is where we process commands."""
"""
try: try:
m = re.match(r"^join (\S+)$", line) m = re.match(r"^join (\S+)$", line)
@ -148,8 +147,7 @@ class Room:
self.have_got_metadata = False self.have_got_metadata = False
def add_participant(self, participant): def add_participant(self, participant):
""" Someone has joined the room """Someone has joined the room"""
"""
self.participants.add(participant) self.participants.add(participant)
self.invited.discard(participant) self.invited.discard(participant)
@ -160,8 +158,7 @@ class Room:
self.oldest_server = server self.oldest_server = server
def add_invited(self, invitee): def add_invited(self, invitee):
""" Someone has been invited to the room """Someone has been invited to the room"""
"""
self.invited.add(invitee) self.invited.add(invitee)
self.servers.add(origin_from_ucid(invitee)) self.servers.add(origin_from_ucid(invitee))
@ -181,8 +178,7 @@ class HomeServer(ReplicationHandler):
self.output = output self.output = output
def on_receive_pdu(self, pdu): def on_receive_pdu(self, pdu):
""" We just received a PDU """We just received a PDU"""
"""
pdu_type = pdu.pdu_type pdu_type = pdu.pdu_type
if pdu_type == "sy.room.message": if pdu_type == "sy.room.message":
@ -199,23 +195,20 @@ class HomeServer(ReplicationHandler):
) )
def _on_message(self, pdu): def _on_message(self, pdu):
""" We received a message """We received a message"""
"""
self.output.print_line( self.output.print_line(
"#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"]) "#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"])
) )
def _on_join(self, context, joinee): def _on_join(self, context, joinee):
""" Someone has joined a room, either a remote user or a local user """Someone has joined a room, either a remote user or a local user"""
"""
room = self._get_or_create_room(context) room = self._get_or_create_room(context)
room.add_participant(joinee) room.add_participant(joinee)
self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED")) self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED"))
def _on_invite(self, origin, context, invitee): def _on_invite(self, origin, context, invitee):
""" Someone has been invited """Someone has been invited"""
"""
room = self._get_or_create_room(context) room = self._get_or_create_room(context)
room.add_invited(invitee) room.add_invited(invitee)
@ -228,8 +221,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_message(self, room_name, sender, body): def send_message(self, room_name, sender, body):
""" Send a message to a room! """Send a message to a room!"""
"""
destinations = yield self.get_servers_for_context(room_name) destinations = yield self.get_servers_for_context(room_name)
try: try:
@ -247,8 +239,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def join_room(self, room_name, sender, joinee): def join_room(self, room_name, sender, joinee):
""" Join a room! """Join a room!"""
"""
self._on_join(room_name, joinee) self._on_join(room_name, joinee)
destinations = yield self.get_servers_for_context(room_name) destinations = yield self.get_servers_for_context(room_name)
@ -269,8 +260,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def invite_to_room(self, room_name, sender, invitee): def invite_to_room(self, room_name, sender, invitee):
""" Invite someone to a room! """Invite someone to a room!"""
"""
self._on_invite(self.server_name, room_name, invitee) self._on_invite(self.server_name, room_name, invitee)
destinations = yield self.get_servers_for_context(room_name) destinations = yield self.get_servers_for_context(room_name)

View File

@ -193,15 +193,12 @@ class TrivialXmppClient:
time.sleep(7) time.sleep(7)
print("SSRC spammer started") print("SSRC spammer started")
while self.running: while self.running:
ssrcMsg = ( ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % {
"<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>"
% {
"tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
"nick": self.userId, "nick": self.userId,
"assrc": self.ssrcs["audio"], "assrc": self.ssrcs["audio"],
"vssrc": self.ssrcs["video"], "vssrc": self.ssrcs["video"],
} }
)
res = self.sendIq(ssrcMsg) res = self.sendIq(ssrcMsg)
print("reply from ssrc announce: ", res) print("reply from ssrc announce: ", res)
time.sleep(10) time.sleep(10)

View File

@ -8,16 +8,16 @@ errors in code.
The necessary tools are detailed below. The necessary tools are detailed below.
First install them with:
pip install -e ".[lint,mypy]"
- **black** - **black**
The Synapse codebase uses [black](https://pypi.org/project/black/) The Synapse codebase uses [black](https://pypi.org/project/black/)
as an opinionated code formatter, ensuring all comitted code is as an opinionated code formatter, ensuring all comitted code is
properly formatted. properly formatted.
First install `black` with:
pip install --upgrade black
Have `black` auto-format your code (it shouldn't change any Have `black` auto-format your code (it shouldn't change any
functionality) with: functionality) with:
@ -28,10 +28,6 @@ The necessary tools are detailed below.
`flake8` is a code checking tool. We require code to pass `flake8` `flake8` is a code checking tool. We require code to pass `flake8`
before being merged into the codebase. before being merged into the codebase.
Install `flake8` with:
pip install --upgrade flake8 flake8-comprehensions
Check all application and test code with: Check all application and test code with:
flake8 synapse tests flake8 synapse tests
@ -41,10 +37,6 @@ The necessary tools are detailed below.
`isort` ensures imports are nicely formatted, and can suggest and `isort` ensures imports are nicely formatted, and can suggest and
auto-fix issues such as double-importing. auto-fix issues such as double-importing.
Install `isort` with:
pip install --upgrade isort
Auto-fix imports with: Auto-fix imports with:
isort -rc synapse tests isort -rc synapse tests

View File

@ -87,7 +87,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
signature = signature.copy_modified( signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, arg_types=arg_types,
arg_names=arg_names,
arg_kinds=arg_kinds,
) )
return signature return signature

View File

@ -97,7 +97,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
# We pin black so that our tests don't start failing on new releases. # We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [ CONDITIONAL_REQUIREMENTS["lint"] = [
"isort==5.7.0", "isort==5.7.0",
"black==19.10b0", "black==20.8b1",
"flake8-comprehensions", "flake8-comprehensions",
"flake8", "flake8",
] ]

View File

@ -89,12 +89,16 @@ class SortedDict(Dict[_KT, _VT]):
def __reduce__( def __reduce__(
self, self,
) -> Tuple[ ) -> Tuple[
Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], Type[SortedDict[_KT, _VT]],
Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
]: ... ]: ...
def __repr__(self) -> str: ... def __repr__(self) -> str: ...
def _check(self) -> None: ... def _check(self) -> None: ...
def islice( def islice(
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, self,
start: Optional[int] = ...,
stop: Optional[int] = ...,
reverse=bool,
) -> Iterator[_KT]: ... ) -> Iterator[_KT]: ...
def bisect_left(self, value: _KT) -> int: ... def bisect_left(self, value: _KT) -> int: ...
def bisect_right(self, value: _KT) -> int: ... def bisect_right(self, value: _KT) -> int: ...

View File

@ -31,7 +31,9 @@ class SortedList(MutableSequence[_T]):
DEFAULT_LOAD_FACTOR: int = ... DEFAULT_LOAD_FACTOR: int = ...
def __init__( def __init__(
self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ..., self,
iterable: Optional[Iterable[_T]] = ...,
key: Optional[_Key[_T]] = ...,
): ... ): ...
# NB: currently mypy does not honour return type, see mypy #3307 # NB: currently mypy does not honour return type, see mypy #3307
@overload @overload
@ -76,10 +78,18 @@ class SortedList(MutableSequence[_T]):
def __len__(self) -> int: ... def __len__(self) -> int: ...
def reverse(self) -> None: ... def reverse(self) -> None: ...
def islice( def islice(
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, self,
start: Optional[int] = ...,
stop: Optional[int] = ...,
reverse=bool,
) -> Iterator[_T]: ... ) -> Iterator[_T]: ...
def _islice( def _islice(
self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool, self,
min_pos: int,
min_idx: int,
max_pos: int,
max_idx: int,
reverse: bool,
) -> Iterator[_T]: ... ) -> Iterator[_T]: ...
def irange( def irange(
self, self,

View File

@ -294,7 +294,10 @@ class Auth:
return user_id, app_service return user_id, app_service
async def get_user_by_access_token( async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False, self,
token: str,
rights: str = "access",
allow_expired: bool = False,
) -> TokenLookupResult: ) -> TokenLookupResult:
"""Validate access token and get user_id from it """Validate access token and get user_id from it
@ -500,7 +503,10 @@ class Auth:
return await self.store.is_server_admin(user) return await self.store.is_server_admin(user)
def compute_auth_events( def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False, self,
event,
current_state_ids: StateMap[str],
for_verification: bool = False,
) -> List[str]: ) -> List[str]:
"""Given an event and current state return the list of event IDs used """Given an event and current state return the list of event IDs used
to auth an event. to auth an event.

View File

@ -128,8 +128,7 @@ class UserTypes:
class RelationTypes: class RelationTypes:
"""The types of relations known to this server. """The types of relations known to this server."""
"""
ANNOTATION = "m.annotation" ANNOTATION = "m.annotation"
REPLACE = "m.replace" REPLACE = "m.replace"

View File

@ -390,8 +390,7 @@ class InvalidCaptchaError(SynapseError):
class LimitExceededError(SynapseError): class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled. """A client has sent too many requests and is being throttled."""
"""
def __init__( def __init__(
self, self,
@ -408,8 +407,7 @@ class LimitExceededError(SynapseError):
class RoomKeysVersionError(SynapseError): class RoomKeysVersionError(SynapseError):
"""A client has tried to upload to a non-current version of the room_keys store """A client has tried to upload to a non-current version of the room_keys store"""
"""
def __init__(self, current_version: str): def __init__(self, current_version: str):
""" """
@ -426,7 +424,9 @@ class UnsupportedRoomVersionError(SynapseError):
def __init__(self, msg: str = "Homeserver does not support this room version"): def __init__(self, msg: str = "Homeserver does not support this room version"):
super().__init__( super().__init__(
code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, code=400,
msg=msg,
errcode=Codes.UNSUPPORTED_ROOM_VERSION,
) )
@ -461,8 +461,7 @@ class IncompatibleRoomVersionError(SynapseError):
class PasswordRefusedError(SynapseError): class PasswordRefusedError(SynapseError):
"""A password has been refused, either during password reset/change or registration. """A password has been refused, either during password reset/change or registration."""
"""
def __init__( def __init__(
self, self,
@ -470,7 +469,9 @@ class PasswordRefusedError(SynapseError):
errcode: str = Codes.WEAK_PASSWORD, errcode: str = Codes.WEAK_PASSWORD,
): ):
super().__init__( super().__init__(
code=400, msg=msg, errcode=errcode, code=400,
msg=msg,
errcode=errcode,
) )

View File

@ -56,8 +56,7 @@ class UserPresenceState(
@classmethod @classmethod
def default(cls, user_id): def default(cls, user_id):
"""Returns a default presence state. """Returns a default presence state."""
"""
return cls( return cls(
user_id=user_id, user_id=user_id,
state=PresenceState.OFFLINE, state=PresenceState.OFFLINE,

View File

@ -313,9 +313,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
refresh_certificate(hs) refresh_certificate(hs)
# Start the tracer # Start the tracer
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
hs
)
# It is now safe to start your Synapse. # It is now safe to start your Synapse.
hs.start_listening(listeners) hs.start_listening(listeners)
@ -370,8 +368,7 @@ def setup_sentry(hs):
def setup_sdnotify(hs): def setup_sdnotify(hs):
"""Adds process state hooks to tell systemd what we are up to. """Adds process state hooks to tell systemd what we are up to."""
"""
# Tell systemd our state, if we're using it. This will silently fail if # Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd. # we're not using systemd.
@ -405,8 +402,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
class _LimitedHostnameResolver: class _LimitedHostnameResolver:
"""Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups. """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups."""
"""
def __init__(self, resolver, max_dns_requests_in_flight): def __init__(self, resolver, max_dns_requests_in_flight):
self._resolver = resolver self._resolver = resolver

View File

@ -421,8 +421,7 @@ class GenericWorkerPresence(BasePresenceHandler):
] ]
async def set_state(self, target_user, state, ignore_status_msg=False): async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user. """Set the presence state of the user."""
"""
presence = state["presence"] presence = state["presence"]
valid_presence = ( valid_presence = (

View File

@ -166,7 +166,10 @@ class ApplicationService:
@cached(num_args=1, cache_context=True) @cached(num_args=1, cache_context=True)
async def matches_user_in_member_list( async def matches_user_in_member_list(
self, room_id: str, store: "DataStore", cache_context: _CacheContext, self,
room_id: str,
store: "DataStore",
cache_context: _CacheContext,
) -> bool: ) -> bool:
"""Check if this service is interested a room based upon it's membership """Check if this service is interested a room based upon it's membership

View File

@ -227,7 +227,9 @@ class ApplicationServiceApi(SimpleHttpClient):
try: try:
await self.put_json( await self.put_json(
uri=uri, json_body=body, args={"access_token": service.hs_token}, uri=uri,
json_body=body,
args={"access_token": service.hs_token},
) )
sent_transactions_counter.labels(service.id).inc() sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events)) sent_events_counter.labels(service.id).inc(len(events))

View File

@ -224,7 +224,9 @@ class Config:
return self.read_templates([filename])[0] return self.read_templates([filename])[0]
def read_templates( def read_templates(
self, filenames: List[str], custom_template_directory: Optional[str] = None, self,
filenames: List[str],
custom_template_directory: Optional[str] = None,
) -> List[jinja2.Template]: ) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables. """Load a list of template files from disk using the given variables.
@ -264,7 +266,10 @@ class Config:
# TODO: switch to synapse.util.templates.build_jinja_env # TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(search_directories) loader = jinja2.FileSystemLoader(search_directories)
env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),) env = jinja2.Environment(
loader=loader,
autoescape=jinja2.select_autoescape(),
)
# Update the environment with our custom filters # Update the environment with our custom filters
env.filters.update( env.filters.update(
@ -825,8 +830,7 @@ class ShardedWorkerHandlingConfig:
instances = attr.ib(type=List[str]) instances = attr.ib(type=List[str])
def should_handle(self, instance_name: str, key: str) -> bool: def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key. """Whether this instance is responsible for handling the given key."""
"""
# If multiple instances are not defined we always return true # If multiple instances are not defined we always return true
if not self.instances or len(self.instances) == 1: if not self.instances or len(self.instances) == 1:
return True return True

View File

@ -18,8 +18,7 @@ from ._base import Config
class AuthConfig(Config): class AuthConfig(Config):
"""Password and login configuration """Password and login configuration"""
"""
section = "auth" section = "auth"

View File

@ -207,8 +207,7 @@ class DatabaseConfig(Config):
) )
def get_single_database(self) -> DatabaseConnectionConfig: def get_single_database(self) -> DatabaseConnectionConfig:
"""Returns the database if there is only one, useful for e.g. tests """Returns the database if there is only one, useful for e.g. tests"""
"""
if not self.databases: if not self.databases:
raise Exception("More than one database exists") raise Exception("More than one database exists")

View File

@ -289,7 +289,8 @@ class EmailConfig(Config):
self.email_notif_template_html, self.email_notif_template_html,
self.email_notif_template_text, self.email_notif_template_text,
) = self.read_templates( ) = self.read_templates(
[notif_template_html, notif_template_text], template_dir, [notif_template_html, notif_template_text],
template_dir,
) )
self.email_notif_for_new_users = email_config.get( self.email_notif_for_new_users = email_config.get(
@ -311,7 +312,8 @@ class EmailConfig(Config):
self.account_validity_template_html, self.account_validity_template_html,
self.account_validity_template_text, self.account_validity_template_text,
) = self.read_templates( ) = self.read_templates(
[expiry_template_html, expiry_template_text], template_dir, [expiry_template_html, expiry_template_text],
template_dir,
) )
subjects_config = email_config.get("subjects", {}) subjects_config = email_config.get("subjects", {})

View File

@ -162,7 +162,10 @@ class LoggingConfig(Config):
) )
logging_group.add_argument( logging_group.add_argument(
"-f", "--log-file", dest="log_file", help=argparse.SUPPRESS, "-f",
"--log-file",
dest="log_file",
help=argparse.SUPPRESS,
) )
def generate_files(self, config, config_dir_path): def generate_files(self, config, config_dir_path):

View File

@ -355,9 +355,10 @@ def _parse_oidc_config_dict(
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {}) ump_config.setdefault("config", {})
(user_mapping_provider_class, user_mapping_provider_config,) = load_module( (
ump_config, config_path + ("user_mapping_provider",) user_mapping_provider_class,
) user_mapping_provider_config,
) = load_module(ump_config, config_path + ("user_mapping_provider",))
# Ensure loaded user mapping module has defined all necessary methods # Ensure loaded user mapping module has defined all necessary methods
required_methods = [ required_methods = [
@ -372,7 +373,11 @@ def _parse_oidc_config_dict(
if missing_methods: if missing_methods:
raise ConfigError( raise ConfigError(
"Class %s is missing required " "Class %s is missing required "
"methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),), "methods: %s"
% (
user_mapping_provider_class,
", ".join(missing_methods),
),
config_path + ("user_mapping_provider", "module"), config_path + ("user_mapping_provider", "module"),
) )

View File

@ -52,7 +52,12 @@ def _6to4(network: IPNetwork) -> IPNetwork:
hex_network = hex(network.first)[2:] hex_network = hex(network.first)[2:]
hex_network = ("0" * (8 - len(hex_network))) + hex_network hex_network = ("0" * (8 - len(hex_network))) + hex_network
return IPNetwork( return IPNetwork(
"2002:%s:%s::/%d" % (hex_network[:4], hex_network[4:], 16 + network.prefixlen,) "2002:%s:%s::/%d"
% (
hex_network[:4],
hex_network[4:],
16 + network.prefixlen,
)
) )
@ -254,7 +259,8 @@ class ServerConfig(Config):
# Whether to require sharing a room with a user to retrieve their # Whether to require sharing a room with a user to retrieve their
# profile data # profile data
self.limit_profile_requests_to_users_who_share_rooms = config.get( self.limit_profile_requests_to_users_who_share_rooms = config.get(
"limit_profile_requests_to_users_who_share_rooms", False, "limit_profile_requests_to_users_who_share_rooms",
False,
) )
if "restrict_public_rooms_to_local_users" in config and ( if "restrict_public_rooms_to_local_users" in config and (
@ -614,7 +620,9 @@ class ServerConfig(Config):
if manhole: if manhole:
self.listeners.append( self.listeners.append(
ListenerConfig( ListenerConfig(
port=manhole, bind_addresses=["127.0.0.1"], type="manhole", port=manhole,
bind_addresses=["127.0.0.1"],
type="manhole",
) )
) )
@ -650,7 +658,8 @@ class ServerConfig(Config):
# and letting the client know which email address is bound to an account and # and letting the client know which email address is bound to an account and
# which one isn't. # which one isn't.
self.request_token_inhibit_3pid_errors = config.get( self.request_token_inhibit_3pid_errors = config.get(
"request_token_inhibit_3pid_errors", False, "request_token_inhibit_3pid_errors",
False,
) )
# List of users trialing the new experimental default push rules. This setting is # List of users trialing the new experimental default push rules. This setting is

View File

@ -35,8 +35,7 @@ class SsoAttributeRequirement:
class SSOConfig(Config): class SSOConfig(Config):
"""SSO Configuration """SSO Configuration"""
"""
section = "sso" section = "sso"

View File

@ -33,8 +33,7 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
@attr.s @attr.s
class InstanceLocationConfig: class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication. """The host and port to talk to an instance via HTTP replication."""
"""
host = attr.ib(type=str) host = attr.ib(type=str)
port = attr.ib(type=int) port = attr.ib(type=int)
@ -54,13 +53,19 @@ class WriterLocations:
) )
typing = attr.ib(default="master", type=str) typing = attr.ib(default="master", type=str)
to_device = attr.ib( to_device = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter, default=["master"],
type=List[str],
converter=_instance_to_list_converter,
) )
account_data = attr.ib( account_data = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter, default=["master"],
type=List[str],
converter=_instance_to_list_converter,
) )
receipts = attr.ib( receipts = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter, default=["master"],
type=List[str],
converter=_instance_to_list_converter,
) )
@ -107,7 +112,9 @@ class WorkerConfig(Config):
if manhole: if manhole:
self.worker_listeners.append( self.worker_listeners.append(
ListenerConfig( ListenerConfig(
port=manhole, bind_addresses=["127.0.0.1"], type="manhole", port=manhole,
bind_addresses=["127.0.0.1"],
type="manhole",
) )
) )

View File

@ -423,7 +423,9 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def check_redaction( def check_redaction(
room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
) -> bool: ) -> bool:
"""Check whether the event sender is allowed to redact the target event. """Check whether the event sender is allowed to redact the target event.
@ -459,7 +461,9 @@ def check_redaction(
def _check_power_levels( def _check_power_levels(
room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
) -> None: ) -> None:
user_list = event.content.get("users", {}) user_list = event.content.get("users", {})
# Validate users # Validate users

View File

@ -98,7 +98,9 @@ class EventBuilder:
return self._state_key is not None return self._state_key is not None
async def build( async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]], self,
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
) -> EventBase: ) -> EventBase:
"""Transform into a fully signed and hashed event """Transform into a fully signed and hashed event

View File

@ -341,8 +341,7 @@ def _encode_state_dict(state_dict):
def _decode_state_dict(input): def _decode_state_dict(input):
"""Decodes a state dict encoded using `_encode_state_dict` above """Decodes a state dict encoded using `_encode_state_dict` above"""
"""
if input is None: if input is None:
return None return None

View File

@ -40,7 +40,8 @@ class ThirdPartyEventRules:
if module is not None: if module is not None:
self.third_party_rules = module( self.third_party_rules = module(
config=config, module_api=hs.get_module_api(), config=config,
module_api=hs.get_module_api(),
) )
async def check_event_allowed( async def check_event_allowed(

View File

@ -750,7 +750,11 @@ class FederationClient(FederationBase):
return resp[1] return resp[1]
async def send_invite( async def send_invite(
self, destination: str, room_id: str, event_id: str, pdu: EventBase, self,
destination: str,
room_id: str,
event_id: str,
pdu: EventBase,
) -> EventBase: ) -> EventBase:
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)

View File

@ -85,7 +85,8 @@ received_queries_counter = Counter(
) )
pdu_process_time = Histogram( pdu_process_time = Histogram(
"synapse_federation_server_pdu_process_time", "Time taken to process an event", "synapse_federation_server_pdu_process_time",
"Time taken to process an event",
) )
@ -373,8 +374,7 @@ class FederationServer(FederationBase):
return pdu_results return pdu_results
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
"""Process the EDUs in a received transaction. """Process the EDUs in a received transaction."""
"""
async def _process_edu(edu_dict): async def _process_edu(edu_dict):
received_edus_counter.inc() received_edus_counter.inc()
@ -437,7 +437,10 @@ class FederationServer(FederationBase):
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
resp = await self._state_ids_resp_cache.wrap( resp = await self._state_ids_resp_cache.wrap(
(room_id, event_id), self._on_state_ids_request_compute, room_id, event_id, (room_id, event_id),
self._on_state_ids_request_compute,
room_id,
event_id,
) )
return 200, resp return 200, resp
@ -906,13 +909,11 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
def register_instance_for_edu(self, edu_type: str, instance_name: str): def register_instance_for_edu(self, edu_type: str, instance_name: str):
"""Register that the EDU handler is on a different instance than master. """Register that the EDU handler is on a different instance than master."""
"""
self._edu_type_to_instance[edu_type] = [instance_name] self._edu_type_to_instance[edu_type] = [instance_name]
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
"""Register that the EDU handler is on multiple instances. """Register that the EDU handler is on multiple instances."""
"""
self._edu_type_to_instance[edu_type] = instance_names self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict): async def on_edu(self, edu_type: str, origin: str, content: dict):

View File

@ -30,8 +30,7 @@ logger = logging.getLogger(__name__)
class TransactionActions: class TransactionActions:
""" Defines persistence actions that relate to handling Transactions. """Defines persistence actions that relate to handling Transactions."""
"""
def __init__(self, datastore): def __init__(self, datastore):
self.store = datastore self.store = datastore
@ -57,8 +56,7 @@ class TransactionActions:
async def set_response( async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None: ) -> None:
"""Persist how we responded to a transaction. """Persist how we responded to a transaction."""
"""
transaction_id = transaction.transaction_id # type: ignore transaction_id = transaction.transaction_id # type: ignore
if not transaction_id: if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id") raise RuntimeError("Cannot persist a transaction with no transaction_id")

View File

@ -468,8 +468,7 @@ class KeyedEduRow(
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
"""Streams EDUs that don't have keys. See KeyedEduRow """Streams EDUs that don't have keys. See KeyedEduRow"""
"""
TypeId = "e" TypeId = "e"
@ -519,7 +518,10 @@ def process_rows_for_federation(transaction_queue, rows):
# them into the appropriate collection and then send them off. # them into the appropriate collection and then send them off.
buff = ParsedFederationStreamData( buff = ParsedFederationStreamData(
presence=[], presence_destinations=[], keyed_edus={}, edus={}, presence=[],
presence_destinations=[],
keyed_edus={},
edus={},
) )
# Parse the rows in the stream and add to the buffer # Parse the rows in the stream and add to the buffer

View File

@ -328,7 +328,9 @@ class FederationSender:
# to allow us to perform catch-up later on if the remote is unreachable # to allow us to perform catch-up later on if the remote is unreachable
# for a while. # for a while.
await self.store.store_destination_rooms_entries( await self.store.store_destination_rooms_entries(
destinations, pdu.room_id, pdu.internal_metadata.stream_ordering, destinations,
pdu.room_id,
pdu.internal_metadata.stream_ordering,
) )
for destination in destinations: for destination in destinations:
@ -616,8 +618,8 @@ class FederationSender:
last_processed = None # type: Optional[str] last_processed = None # type: Optional[str]
while True: while True:
destinations_to_wake = await self.store.get_catch_up_outstanding_destinations( destinations_to_wake = (
last_processed await self.store.get_catch_up_outstanding_destinations(last_processed)
) )
if not destinations_to_wake: if not destinations_to_wake:

View File

@ -85,7 +85,8 @@ class PerDestinationQueue:
# processing. We have a guard in `attempt_new_transaction` that # processing. We have a guard in `attempt_new_transaction` that
# ensure we don't start sending stuff. # ensure we don't start sending stuff.
logger.error( logger.error(
"Create a per destination queue for %s on wrong worker", destination, "Create a per destination queue for %s on wrong worker",
destination,
) )
self._should_send_on_this_instance = False self._should_send_on_this_instance = False
@ -440,9 +441,11 @@ class PerDestinationQueue:
if first_catch_up_check: if first_catch_up_check:
# first catchup so get last_successful_stream_ordering from database # first catchup so get last_successful_stream_ordering from database
self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering( self._last_successful_stream_ordering = (
await self._store.get_destination_last_successful_stream_ordering(
self._destination self._destination
) )
)
if self._last_successful_stream_ordering is None: if self._last_successful_stream_ordering is None:
# if it's still None, then this means we don't have the information # if it's still None, then this means we don't have the information
@ -457,7 +460,8 @@ class PerDestinationQueue:
# get at most 50 catchup room/PDUs # get at most 50 catchup room/PDUs
while True: while True:
event_ids = await self._store.get_catch_up_room_event_ids( event_ids = await self._store.get_catch_up_room_event_ids(
self._destination, self._last_successful_stream_ordering, self._destination,
self._last_successful_stream_ordering,
) )
if not event_ids: if not event_ids:

View File

@ -65,7 +65,10 @@ class TransactionManager:
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
async def send_new_transaction( async def send_new_transaction(
self, destination: str, pdus: List[EventBase], edus: List[Edu], self,
destination: str,
pdus: List[EventBase],
edus: List[Edu],
) -> bool: ) -> bool:
""" """
Args: Args:

View File

@ -551,8 +551,7 @@ class TransportLayerClient:
@log_function @log_function
def get_group_profile(self, destination, group_id, requester_user_id): def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile """Get a group profile"""
"""
path = _create_v1_path("/groups/%s/profile", group_id) path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.get_json( return self.client.get_json(
@ -584,8 +583,7 @@ class TransportLayerClient:
@log_function @log_function
def get_group_summary(self, destination, group_id, requester_user_id): def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary """Get a group summary"""
"""
path = _create_v1_path("/groups/%s/summary", group_id) path = _create_v1_path("/groups/%s/summary", group_id)
return self.client.get_json( return self.client.get_json(
@ -597,8 +595,7 @@ class TransportLayerClient:
@log_function @log_function
def get_rooms_in_group(self, destination, group_id, requester_user_id): def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group """Get all rooms in a group"""
"""
path = _create_v1_path("/groups/%s/rooms", group_id) path = _create_v1_path("/groups/%s/rooms", group_id)
return self.client.get_json( return self.client.get_json(
@ -611,8 +608,7 @@ class TransportLayerClient:
def add_room_to_group( def add_room_to_group(
self, destination, group_id, requester_user_id, room_id, content self, destination, group_id, requester_user_id, room_id, content
): ):
"""Add a room to a group """Add a room to a group"""
"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.post_json( return self.client.post_json(
@ -626,8 +622,7 @@ class TransportLayerClient:
def update_room_in_group( def update_room_in_group(
self, destination, group_id, requester_user_id, room_id, config_key, content self, destination, group_id, requester_user_id, room_id, config_key, content
): ):
"""Update room in group """Update room in group"""
"""
path = _create_v1_path( path = _create_v1_path(
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
) )
@ -641,8 +636,7 @@ class TransportLayerClient:
) )
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group """Remove a room from a group"""
"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.delete_json( return self.client.delete_json(
@ -654,8 +648,7 @@ class TransportLayerClient:
@log_function @log_function
def get_users_in_group(self, destination, group_id, requester_user_id): def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group """Get users in a group"""
"""
path = _create_v1_path("/groups/%s/users", group_id) path = _create_v1_path("/groups/%s/users", group_id)
return self.client.get_json( return self.client.get_json(
@ -667,8 +660,7 @@ class TransportLayerClient:
@log_function @log_function
def get_invited_users_in_group(self, destination, group_id, requester_user_id): def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group """Get users that have been invited to a group"""
"""
path = _create_v1_path("/groups/%s/invited_users", group_id) path = _create_v1_path("/groups/%s/invited_users", group_id)
return self.client.get_json( return self.client.get_json(
@ -680,8 +672,7 @@ class TransportLayerClient:
@log_function @log_function
def accept_group_invite(self, destination, group_id, user_id, content): def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite """Accept a group invite"""
"""
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
return self.client.post_json( return self.client.post_json(
@ -690,8 +681,7 @@ class TransportLayerClient:
@log_function @log_function
def join_group(self, destination, group_id, user_id, content): def join_group(self, destination, group_id, user_id, content):
"""Attempts to join a group """Attempts to join a group"""
"""
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json( return self.client.post_json(
@ -702,8 +692,7 @@ class TransportLayerClient:
def invite_to_group( def invite_to_group(
self, destination, group_id, user_id, requester_user_id, content self, destination, group_id, user_id, requester_user_id, content
): ):
"""Invite a user to a group """Invite a user to a group"""
"""
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
return self.client.post_json( return self.client.post_json(
@ -730,8 +719,7 @@ class TransportLayerClient:
def remove_user_from_group( def remove_user_from_group(
self, destination, group_id, requester_user_id, user_id, content self, destination, group_id, requester_user_id, user_id, content
): ):
"""Remove a user from a group """Remove a user from a group"""
"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
return self.client.post_json( return self.client.post_json(
@ -772,8 +760,7 @@ class TransportLayerClient:
def update_group_summary_room( def update_group_summary_room(
self, destination, group_id, user_id, room_id, category_id, content self, destination, group_id, user_id, room_id, category_id, content
): ):
"""Update a room entry in a group summary """Update a room entry in a group summary"""
"""
if category_id: if category_id:
path = _create_v1_path( path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s", "/groups/%s/summary/categories/%s/rooms/%s",
@ -796,8 +783,7 @@ class TransportLayerClient:
def delete_group_summary_room( def delete_group_summary_room(
self, destination, group_id, user_id, room_id, category_id self, destination, group_id, user_id, room_id, category_id
): ):
"""Delete a room entry in a group summary """Delete a room entry in a group summary"""
"""
if category_id: if category_id:
path = _create_v1_path( path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s", "/groups/%s/summary/categories/%s/rooms/%s",
@ -817,8 +803,7 @@ class TransportLayerClient:
@log_function @log_function
def get_group_categories(self, destination, group_id, requester_user_id): def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group """Get all categories in a group"""
"""
path = _create_v1_path("/groups/%s/categories", group_id) path = _create_v1_path("/groups/%s/categories", group_id)
return self.client.get_json( return self.client.get_json(
@ -830,8 +815,7 @@ class TransportLayerClient:
@log_function @log_function
def get_group_category(self, destination, group_id, requester_user_id, category_id): def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group """Get category info in a group"""
"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.get_json( return self.client.get_json(
@ -845,8 +829,7 @@ class TransportLayerClient:
def update_group_category( def update_group_category(
self, destination, group_id, requester_user_id, category_id, content self, destination, group_id, requester_user_id, category_id, content
): ):
"""Update a category in a group """Update a category in a group"""
"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.post_json( return self.client.post_json(
@ -861,8 +844,7 @@ class TransportLayerClient:
def delete_group_category( def delete_group_category(
self, destination, group_id, requester_user_id, category_id self, destination, group_id, requester_user_id, category_id
): ):
"""Delete a category in a group """Delete a category in a group"""
"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.delete_json( return self.client.delete_json(
@ -874,8 +856,7 @@ class TransportLayerClient:
@log_function @log_function
def get_group_roles(self, destination, group_id, requester_user_id): def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group """Get all roles in a group"""
"""
path = _create_v1_path("/groups/%s/roles", group_id) path = _create_v1_path("/groups/%s/roles", group_id)
return self.client.get_json( return self.client.get_json(
@ -887,8 +868,7 @@ class TransportLayerClient:
@log_function @log_function
def get_group_role(self, destination, group_id, requester_user_id, role_id): def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info """Get a roles info"""
"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.get_json( return self.client.get_json(
@ -902,8 +882,7 @@ class TransportLayerClient:
def update_group_role( def update_group_role(
self, destination, group_id, requester_user_id, role_id, content self, destination, group_id, requester_user_id, role_id, content
): ):
"""Update a role in a group """Update a role in a group"""
"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.post_json( return self.client.post_json(
@ -916,8 +895,7 @@ class TransportLayerClient:
@log_function @log_function
def delete_group_role(self, destination, group_id, requester_user_id, role_id): def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group """Delete a role in a group"""
"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.delete_json( return self.client.delete_json(
@ -931,8 +909,7 @@ class TransportLayerClient:
def update_group_summary_user( def update_group_summary_user(
self, destination, group_id, requester_user_id, user_id, role_id, content self, destination, group_id, requester_user_id, user_id, role_id, content
): ):
"""Update a users entry in a group """Update a users entry in a group"""
"""
if role_id: if role_id:
path = _create_v1_path( path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
@ -950,8 +927,7 @@ class TransportLayerClient:
@log_function @log_function
def set_group_join_policy(self, destination, group_id, requester_user_id, content): def set_group_join_policy(self, destination, group_id, requester_user_id, content):
"""Sets the join policy for a group """Sets the join policy for a group"""
"""
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
return self.client.put_json( return self.client.put_json(
@ -966,8 +942,7 @@ class TransportLayerClient:
def delete_group_summary_user( def delete_group_summary_user(
self, destination, group_id, requester_user_id, user_id, role_id self, destination, group_id, requester_user_id, user_id, role_id
): ):
"""Delete a users entry in a group """Delete a users entry in a group"""
"""
if role_id: if role_id:
path = _create_v1_path( path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
@ -983,8 +958,7 @@ class TransportLayerClient:
) )
def bulk_get_publicised_groups(self, destination, user_ids): def bulk_get_publicised_groups(self, destination, user_ids):
"""Get the groups a list of users are publicising """Get the groups a list of users are publicising"""
"""
path = _create_v1_path("/get_groups_publicised") path = _create_v1_path("/get_groups_publicised")

View File

@ -364,7 +364,10 @@ class BaseFederationServlet:
continue continue
server.register_paths( server.register_paths(
method, (pattern,), self._wrap(code), self.__class__.__name__, method,
(pattern,),
self._wrap(code),
self.__class__.__name__,
) )
@ -855,8 +858,7 @@ class FederationVersionServlet(BaseFederationServlet):
class FederationGroupsProfileServlet(BaseFederationServlet): class FederationGroupsProfileServlet(BaseFederationServlet):
"""Get/set the basic profile of a group on behalf of a user """Get/set the basic profile of a group on behalf of a user"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/profile" PATH = "/groups/(?P<group_id>[^/]*)/profile"
@ -895,8 +897,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
class FederationGroupsRoomsServlet(BaseFederationServlet): class FederationGroupsRoomsServlet(BaseFederationServlet):
"""Get the rooms in a group on behalf of a user """Get the rooms in a group on behalf of a user"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms" PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@ -911,8 +912,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsServlet(BaseFederationServlet):
"""Add/remove room from group """Add/remove room from group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@ -940,8 +940,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
"""Update room config in group """Update room config in group"""
"""
PATH = ( PATH = (
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@ -961,8 +960,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
class FederationGroupsUsersServlet(BaseFederationServlet): class FederationGroupsUsersServlet(BaseFederationServlet):
"""Get the users in a group on behalf of a user """Get the users in a group on behalf of a user"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/users" PATH = "/groups/(?P<group_id>[^/]*)/users"
@ -977,8 +975,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
class FederationGroupsInvitedUsersServlet(BaseFederationServlet): class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
"""Get the users that have been invited to a group """Get the users that have been invited to a group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users" PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@ -995,8 +992,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
class FederationGroupsInviteServlet(BaseFederationServlet): class FederationGroupsInviteServlet(BaseFederationServlet):
"""Ask a group server to invite someone to the group """Ask a group server to invite someone to the group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@ -1013,8 +1009,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
class FederationGroupsAcceptInviteServlet(BaseFederationServlet): class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
"""Accept an invitation from the group server """Accept an invitation from the group server"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@ -1028,8 +1023,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
class FederationGroupsJoinServlet(BaseFederationServlet): class FederationGroupsJoinServlet(BaseFederationServlet):
"""Attempt to join a group """Attempt to join a group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@ -1043,8 +1037,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
class FederationGroupsRemoveUserServlet(BaseFederationServlet): class FederationGroupsRemoveUserServlet(BaseFederationServlet):
"""Leave or kick a user from the group """Leave or kick a user from the group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@ -1061,8 +1054,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
class FederationGroupsLocalInviteServlet(BaseFederationServlet): class FederationGroupsLocalInviteServlet(BaseFederationServlet):
"""A group server has invited a local user """A group server has invited a local user"""
"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@ -1076,8 +1068,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
"""A group server has removed a local user """A group server has removed a local user"""
"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@ -1093,8 +1084,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
"""A group or user's server renews their attestation """A group or user's server renews their attestation"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
@ -1156,8 +1146,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
class FederationGroupsCategoriesServlet(BaseFederationServlet): class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group """Get all categories for a group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/?" PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@ -1172,8 +1161,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
class FederationGroupsCategoryServlet(BaseFederationServlet): class FederationGroupsCategoryServlet(BaseFederationServlet):
"""Add/remove/get a category in a group """Add/remove/get a category in a group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@ -1218,8 +1206,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
class FederationGroupsRolesServlet(BaseFederationServlet): class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group """Get roles in a group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/?" PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@ -1234,8 +1221,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
class FederationGroupsRoleServlet(BaseFederationServlet): class FederationGroupsRoleServlet(BaseFederationServlet):
"""Add/remove/get a role in a group """Add/remove/get a role in a group"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@ -1325,8 +1311,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
"""Get roles in a group """Get roles in a group"""
"""
PATH = "/get_groups_publicised" PATH = "/get_groups_publicised"
@ -1339,8 +1324,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
"""Sets whether a group is joinable without an invite or knock """Sets whether a group is joinable without an invite or knock"""
"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"

View File

@ -61,8 +61,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning: class GroupAttestationSigning:
"""Creates and verifies group attestations. """Creates and verifies group attestations."""
"""
def __init__(self, hs): def __init__(self, hs):
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -125,8 +124,7 @@ class GroupAttestationSigning:
class GroupAttestionRenewer: class GroupAttestionRenewer:
"""Responsible for sending and receiving attestation updates. """Responsible for sending and receiving attestation updates."""
"""
def __init__(self, hs): def __init__(self, hs):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -142,8 +140,7 @@ class GroupAttestionRenewer:
) )
async def on_renew_attestation(self, group_id, user_id, content): async def on_renew_attestation(self, group_id, user_id, content):
"""When a remote updates an attestation """When a remote updates an attestation"""
"""
attestation = content["attestation"] attestation = content["attestation"]
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
@ -161,8 +158,7 @@ class GroupAttestionRenewer:
return run_as_background_process("renew_attestations", self._renew_attestations) return run_as_background_process("renew_attestations", self._renew_attestations)
async def _renew_attestations(self): async def _renew_attestations(self):
"""Called periodically to check if we need to update any of our attestations """Called periodically to check if we need to update any of our attestations"""
"""
now = self.clock.time_msec() now = self.clock.time_msec()

View File

@ -165,16 +165,14 @@ class GroupsServerWorkerHandler:
} }
async def get_group_categories(self, group_id, requester_user_id): async def get_group_categories(self, group_id, requester_user_id):
"""Get all categories in a group (as seen by user) """Get all categories in a group (as seen by user)"""
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
categories = await self.store.get_group_categories(group_id=group_id) categories = await self.store.get_group_categories(group_id=group_id)
return {"categories": categories} return {"categories": categories}
async def get_group_category(self, group_id, requester_user_id, category_id): async def get_group_category(self, group_id, requester_user_id, category_id):
"""Get a specific category in a group (as seen by user) """Get a specific category in a group (as seen by user)"""
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = await self.store.get_group_category( res = await self.store.get_group_category(
@ -186,24 +184,21 @@ class GroupsServerWorkerHandler:
return res return res
async def get_group_roles(self, group_id, requester_user_id): async def get_group_roles(self, group_id, requester_user_id):
"""Get all roles in a group (as seen by user) """Get all roles in a group (as seen by user)"""
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
roles = await self.store.get_group_roles(group_id=group_id) roles = await self.store.get_group_roles(group_id=group_id)
return {"roles": roles} return {"roles": roles}
async def get_group_role(self, group_id, requester_user_id, role_id): async def get_group_role(self, group_id, requester_user_id, role_id):
"""Get a specific role in a group (as seen by user) """Get a specific role in a group (as seen by user)"""
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = await self.store.get_group_role(group_id=group_id, role_id=role_id) res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
return res return res
async def get_group_profile(self, group_id, requester_user_id): async def get_group_profile(self, group_id, requester_user_id):
"""Get the group profile as seen by requester_user_id """Get the group profile as seen by requester_user_id"""
"""
await self.check_group_is_ours(group_id, requester_user_id) await self.check_group_is_ours(group_id, requester_user_id)
@ -350,8 +345,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def update_group_summary_room( async def update_group_summary_room(
self, group_id, requester_user_id, room_id, category_id, content self, group_id, requester_user_id, room_id, category_id, content
): ):
"""Add/update a room to the group summary """Add/update a room to the group summary"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -375,8 +369,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def delete_group_summary_room( async def delete_group_summary_room(
self, group_id, requester_user_id, room_id, category_id self, group_id, requester_user_id, room_id, category_id
): ):
"""Remove a room from the summary """Remove a room from the summary"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -409,8 +402,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def update_group_category( async def update_group_category(
self, group_id, requester_user_id, category_id, content self, group_id, requester_user_id, category_id, content
): ):
"""Add/Update a group category """Add/Update a group category"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -428,8 +420,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def delete_group_category(self, group_id, requester_user_id, category_id): async def delete_group_category(self, group_id, requester_user_id, category_id):
"""Delete a group category """Delete a group category"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -441,8 +432,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def update_group_role(self, group_id, requester_user_id, role_id, content): async def update_group_role(self, group_id, requester_user_id, role_id, content):
"""Add/update a role in a group """Add/update a role in a group"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -458,8 +448,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def delete_group_role(self, group_id, requester_user_id, role_id): async def delete_group_role(self, group_id, requester_user_id, role_id):
"""Remove role from group """Remove role from group"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -471,8 +460,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def update_group_summary_user( async def update_group_summary_user(
self, group_id, requester_user_id, user_id, role_id, content self, group_id, requester_user_id, user_id, role_id, content
): ):
"""Add/update a users entry in the group summary """Add/update a users entry in the group summary"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -494,8 +482,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def delete_group_summary_user( async def delete_group_summary_user(
self, group_id, requester_user_id, user_id, role_id self, group_id, requester_user_id, user_id, role_id
): ):
"""Remove a user from the group summary """Remove a user from the group summary"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -507,8 +494,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def update_group_profile(self, group_id, requester_user_id, content): async def update_group_profile(self, group_id, requester_user_id, content):
"""Update the group profile """Update the group profile"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -539,8 +525,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
await self.store.update_group_profile(group_id, profile) await self.store.update_group_profile(group_id, profile)
async def add_room_to_group(self, group_id, requester_user_id, room_id, content): async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
"""Add room to group """Add room to group"""
"""
RoomID.from_string(room_id) # Ensure valid room id RoomID.from_string(room_id) # Ensure valid room id
await self.check_group_is_ours( await self.check_group_is_ours(
@ -556,8 +541,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def update_room_in_group( async def update_room_in_group(
self, group_id, requester_user_id, room_id, config_key, content self, group_id, requester_user_id, room_id, config_key, content
): ):
"""Update room in group """Update room in group"""
"""
RoomID.from_string(room_id) # Ensure valid room id RoomID.from_string(room_id) # Ensure valid room id
await self.check_group_is_ours( await self.check_group_is_ours(
@ -576,8 +560,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def remove_room_from_group(self, group_id, requester_user_id, room_id): async def remove_room_from_group(self, group_id, requester_user_id, room_id):
"""Remove room from group """Remove room from group"""
"""
await self.check_group_is_ours( await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
) )
@ -587,8 +570,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {} return {}
async def invite_to_group(self, group_id, user_id, requester_user_id, content): async def invite_to_group(self, group_id, user_id, requester_user_id, content):
"""Invite user to group """Invite user to group"""
"""
group = await self.check_group_is_ours( group = await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -724,8 +706,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"state": "join", "attestation": local_attestation} return {"state": "join", "attestation": local_attestation}
async def knock(self, group_id, requester_user_id, content): async def knock(self, group_id, requester_user_id, content):
"""A user requests becoming a member of the group """A user requests becoming a member of the group"""
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
raise NotImplementedError() raise NotImplementedError()
@ -918,8 +899,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
def _parse_join_policy_from_contents(content): def _parse_join_policy_from_contents(content):
"""Given a content for a request, return the specified join policy or None """Given a content for a request, return the specified join policy or None"""
"""
join_policy_dict = content.get("m.join_policy") join_policy_dict = content.get("m.join_policy")
if join_policy_dict: if join_policy_dict:
@ -929,8 +909,7 @@ def _parse_join_policy_from_contents(content):
def _parse_join_policy_dict(join_policy_dict): def _parse_join_policy_dict(join_policy_dict):
"""Given a dict for the "m.join_policy" config return the join policy specified """Given a dict for the "m.join_policy" config return the join policy specified"""
"""
join_policy_type = join_policy_dict.get("type") join_policy_type = join_policy_dict.get("type")
if not join_policy_type: if not join_policy_type:
return "invite" return "invite"

View File

@ -203,13 +203,11 @@ class AdminHandler(BaseHandler):
class ExfiltrationWriter(metaclass=abc.ABCMeta): class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data. """Interface used to specify how to write exported data."""
"""
@abc.abstractmethod @abc.abstractmethod
def write_events(self, room_id: str, events: List[EventBase]) -> None: def write_events(self, room_id: str, events: List[EventBase]) -> None:
"""Write a batch of events for a room. """Write a batch of events for a room."""
"""
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod

View File

@ -290,7 +290,9 @@ class ApplicationServicesHandler:
if not interested: if not interested:
continue continue
presence_events, _ = await presence_source.get_new_events( presence_events, _ = await presence_source.get_new_events(
user=user, service=service, from_key=from_key, user=user,
service=service,
from_key=from_key,
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
events.extend( events.extend(

View File

@ -120,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier(
# Ensure the identifier has a type # Ensure the identifier has a type
if "type" not in identifier: if "type" not in identifier:
raise SynapseError( raise SynapseError(
400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, 400,
"'identifier' dict has no key 'type'",
errcode=Codes.MISSING_PARAM,
) )
return identifier return identifier
@ -351,7 +353,11 @@ class AuthHandler(BaseHandler):
try: try:
result, params, session_id = await self.check_ui_auth( result, params, session_id = await self.check_ui_auth(
flows, request, request_body, description, get_new_session_data, flows,
request,
request_body,
description,
get_new_session_data,
) )
except LoginError: except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise). # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
@ -379,8 +385,7 @@ class AuthHandler(BaseHandler):
return params, session_id return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
"""Get a list of the authentication types this user can use """Get a list of the authentication types this user can use"""
"""
ui_auth_types = set() ui_auth_types = set()
@ -723,7 +728,9 @@ class AuthHandler(BaseHandler):
} }
def _auth_dict_for_flows( def _auth_dict_for_flows(
self, flows: List[List[str]], session_id: str, self,
flows: List[List[str]],
session_id: str,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
public_flows = [] public_flows = []
for f in flows: for f in flows:
@ -880,7 +887,9 @@ class AuthHandler(BaseHandler):
return self._supported_login_types return self._supported_login_types
async def validate_login( async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False, self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API """Authenticates the user for the /login API
@ -1023,7 +1032,9 @@ class AuthHandler(BaseHandler):
raise raise
async def _validate_userid_login( async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any], self,
username: str,
login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login """Helper for validate_login
@ -1446,7 +1457,8 @@ class AuthHandler(BaseHandler):
# is considered OK since the newest SSO attributes should be most valid. # is considered OK since the newest SSO attributes should be most valid.
if extra_attributes: if extra_attributes:
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
self._clock.time_msec(), extra_attributes, self._clock.time_msec(),
extra_attributes,
) )
# Create a login token # Create a login token
@ -1702,5 +1714,9 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out # This might return an awaitable, if it does block the log out
# until it completes. # until it completes.
await maybe_awaitable( await maybe_awaitable(
g(user_id=user_id, device_id=device_id, access_token=access_token,) g(
user_id=user_id,
device_id=device_id,
access_token=access_token,
)
) )

View File

@ -33,8 +33,7 @@ logger = logging.getLogger(__name__)
class CasError(Exception): class CasError(Exception):
"""Used to catch errors when validating the CAS ticket. """Used to catch errors when validating the CAS ticket."""
"""
def __init__(self, error, error_description=None): def __init__(self, error, error_description=None):
self.error = error self.error = error
@ -100,7 +99,10 @@ class CasHandler:
Returns: Returns:
The URL to use as a "service" parameter. The URL to use as a "service" parameter.
""" """
return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) return "%s?%s" % (
self._cas_service_url,
urllib.parse.urlencode(args),
)
async def _validate_ticket( async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str] self, ticket: str, service_args: Dict[str, str]
@ -296,7 +298,10 @@ class CasHandler:
# first check if we're doing a UIA # first check if we're doing a UIA
if session: if session:
return await self._sso_handler.complete_sso_ui_auth_request( return await self._sso_handler.complete_sso_ui_auth_request(
self.idp_id, cas_response.username, session, request, self.idp_id,
cas_response.username,
session,
request,
) )
# otherwise, we're handling a login request. # otherwise, we're handling a login request.
@ -366,7 +371,8 @@ class CasHandler:
user_id = UserID(localpart, self._hostname).to_string() user_id = UserID(localpart, self._hostname).to_string()
logger.debug( logger.debug(
"Looking for existing account based on mapped %s", user_id, "Looking for existing account based on mapped %s",
user_id,
) )
users = await self._store.get_users_by_id_case_insensitive(user_id) users = await self._store.get_users_by_id_case_insensitive(user_id)

View File

@ -196,8 +196,7 @@ class DeactivateAccountHandler(BaseHandler):
run_as_background_process("user_parter_loop", self._user_parter_loop) run_as_background_process("user_parter_loop", self._user_parter_loop)
async def _user_parter_loop(self) -> None: async def _user_parter_loop(self) -> None:
"""Loop that parts deactivated users from rooms """Loop that parts deactivated users from rooms"""
"""
self._user_parter_running = True self._user_parter_running = True
logger.info("Starting user parter") logger.info("Starting user parter")
try: try:
@ -214,8 +213,7 @@ class DeactivateAccountHandler(BaseHandler):
self._user_parter_running = False self._user_parter_running = False
async def _part_user(self, user_id: str) -> None: async def _part_user(self, user_id: str) -> None:
"""Causes the given user_id to leave all the rooms they're joined to """Causes the given user_id to leave all the rooms they're joined to"""
"""
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
rooms_for_user = await self.store.get_rooms_for_user(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id)

View File

@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler):
device id of the dehydrated device device id of the dehydrated device
""" """
device_id = await self.check_device_registered( device_id = await self.check_device_registered(
user_id, None, initial_device_display_name, user_id,
None,
initial_device_display_name,
) )
old_device_id = await self.store.store_dehydrated_device( old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data user_id, device_id, device_data
@ -803,7 +805,8 @@ class DeviceListUpdater:
try: try:
# Try to resync the current user's devices list. # Try to resync the current user's devices list.
result = await self.user_device_resync( result = await self.user_device_resync(
user_id=user_id, mark_failed_as_stale=False, user_id=user_id,
mark_failed_as_stale=False,
) )
# user_device_resync only returns a result if it managed to # user_device_resync only returns a result if it managed to
@ -813,14 +816,17 @@ class DeviceListUpdater:
# self.store.update_remote_device_list_cache). # self.store.update_remote_device_list_cache).
if result: if result:
logger.debug( logger.debug(
"Successfully resynced the device list for %s", user_id, "Successfully resynced the device list for %s",
user_id,
) )
except Exception as e: except Exception as e:
# If there was an issue resyncing this user, e.g. if the remote # If there was an issue resyncing this user, e.g. if the remote
# server sent a malformed result, just log the error instead of # server sent a malformed result, just log the error instead of
# aborting all the subsequent resyncs. # aborting all the subsequent resyncs.
logger.debug( logger.debug(
"Could not resync the device list for %s: %s", user_id, e, "Could not resync the device list for %s: %s",
user_id,
e,
) )
finally: finally:
# Allow future calls to retry resyncinc out of sync device lists. # Allow future calls to retry resyncinc out of sync device lists.
@ -855,7 +861,9 @@ class DeviceListUpdater:
return None return None
except (RequestSendFailed, HttpResponseException) as e: except (RequestSendFailed, HttpResponseException) as e:
logger.warning( logger.warning(
"Failed to handle device list update for %s: %s", user_id, e, "Failed to handle device list update for %s: %s",
user_id,
e,
) )
if mark_failed_as_stale: if mark_failed_as_stale:
@ -931,7 +939,9 @@ class DeviceListUpdater:
# Handle cross-signing keys. # Handle cross-signing keys.
cross_signing_device_ids = await self.process_cross_signing_key_update( cross_signing_device_ids = await self.process_cross_signing_key_update(
user_id, master_key, self_signing_key, user_id,
master_key,
self_signing_key,
) )
device_ids = device_ids + cross_signing_device_ids device_ids = device_ids + cross_signing_device_ids

View File

@ -62,7 +62,8 @@ class DeviceMessageHandler:
) )
else: else:
hs.get_federation_registry().register_instances_for_edu( hs.get_federation_registry().register_instances_for_edu(
"m.direct_to_device", hs.config.worker.writers.to_device, "m.direct_to_device",
hs.config.worker.writers.to_device,
) )
# The handler to call when we think a user's device list might be out of # The handler to call when we think a user's device list might be out of
@ -73,8 +74,8 @@ class DeviceMessageHandler:
hs.get_device_handler().device_list_updater.user_device_resync hs.get_device_handler().device_list_updater.user_device_resync
) )
else: else:
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( self._user_device_resync = (
hs ReplicationUserDevicesResyncRestServlet.make_client(hs)
) )
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:

View File

@ -61,8 +61,8 @@ class E2eKeysHandler:
self._is_master = hs.config.worker_app is None self._is_master = hs.config.worker_app is None
if not self._is_master: if not self._is_master:
self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( self._user_device_resync_client = (
hs ReplicationUserDevicesResyncRestServlet.make_client(hs)
) )
else: else:
# Only register this edu handler on master as it requires writing # Only register this edu handler on master as it requires writing
@ -391,8 +391,7 @@ class E2eKeysHandler:
async def on_federation_query_client_keys( async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]] self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict: ) -> JsonDict:
""" Handle a device key query from a federated server """Handle a device key query from a federated server"""
"""
device_keys_query = query_body.get( device_keys_query = query_body.get(
"device_keys", {} "device_keys", {}
) # type: Dict[str, Optional[List[str]]] ) # type: Dict[str, Optional[List[str]]]
@ -1065,7 +1064,9 @@ class E2eKeysHandler:
return key, key_id, verify_key return key, key_id, verify_key
async def _retrieve_cross_signing_keys_for_remote_user( async def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str, self,
user: UserID,
desired_key_type: str,
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database """Queries cross-signing keys for a remote user and saves them to the database
@ -1269,8 +1270,7 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
@attr.s(slots=True) @attr.s(slots=True)
class SignatureListItem: class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys. """An item in the signature list as used by upload_signatures_for_device_keys."""
"""
signing_key_id = attr.ib(type=str) signing_key_id = attr.ib(type=str)
target_user_id = attr.ib(type=str) target_user_id = attr.ib(type=str)
@ -1355,8 +1355,12 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates) logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates: for master_key, self_signing_key in pending_updates:
new_device_ids = await device_list_updater.process_cross_signing_key_update( new_device_ids = (
user_id, master_key, self_signing_key, await device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
) )
device_ids = device_ids + new_device_ids device_ids = device_ids + new_device_ids

View File

@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler):
room_id: Optional[str] = None, room_id: Optional[str] = None,
is_guest: bool = False, is_guest: bool = False,
) -> JsonDict: ) -> JsonDict:
"""Fetches the events stream for a given user. """Fetches the events stream for a given user."""
"""
if room_id: if room_id:
blocked = await self.store.is_room_blocked(room_id) blocked = await self.store.is_room_blocked(room_id)

View File

@ -150,11 +150,11 @@ class FederationHandler(BaseHandler):
) )
if hs.config.worker_app: if hs.config.worker_app:
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( self._user_device_resync = (
hs ReplicationUserDevicesResyncRestServlet.make_client(hs)
) )
self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( self._maybe_store_room_on_outlier_membership = (
hs ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
) )
else: else:
self._device_list_updater = hs.get_device_handler().device_list_updater self._device_list_updater = hs.get_device_handler().device_list_updater
@ -368,7 +368,8 @@ class FederationHandler(BaseHandler):
# know about # know about
for p in prevs - seen: for p in prevs - seen:
logger.info( logger.info(
"Requesting state at missing prev_event %s", event_id, "Requesting state at missing prev_event %s",
event_id,
) )
with nested_logging_context(p): with nested_logging_context(p):
@ -388,13 +389,15 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store( state_map = (
await self._state_resolution_handler.resolve_events_with_store(
room_id, room_id,
room_version, room_version,
state_maps, state_maps,
event_map, event_map,
state_res_store=StateResolutionStore(self.store), state_res_store=StateResolutionStore(self.store),
) )
)
# We need to give _process_received_pdu the actual state events # We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now. # rather than event ids, so generate that now.
@ -687,7 +690,10 @@ class FederationHandler(BaseHandler):
return fetched_events return fetched_events
async def _process_received_pdu( async def _process_received_pdu(
self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
): ):
"""Called when we have a new pdu. We need to do auth checks and put it """Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler. through the StateHandler.
@ -1204,11 +1210,16 @@ class FederationHandler(BaseHandler):
with nested_logging_context(event_id): with nested_logging_context(event_id):
try: try:
event = await self.federation_client.get_pdu( event = await self.federation_client.get_pdu(
[destination], event_id, room_version, outlier=True, [destination],
event_id,
room_version,
outlier=True,
) )
if event is None: if event is None:
logger.warning( logger.warning(
"Server %s didn't return event %s", destination, event_id, "Server %s didn't return event %s",
destination,
event_id,
) )
return return
@ -1235,7 +1246,8 @@ class FederationHandler(BaseHandler):
if aid not in event_map if aid not in event_map
] ]
persisted_events = await self.store.get_events( persisted_events = await self.store.get_events(
auth_events, allow_rejected=True, auth_events,
allow_rejected=True,
) )
event_infos = [] event_infos = []
@ -1251,7 +1263,9 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth)) event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events( await self._handle_new_events(
destination, room_id, event_infos, destination,
room_id,
event_infos,
) )
def _sanity_check_event(self, ev): def _sanity_check_event(self, ev):
@ -1388,7 +1402,8 @@ class FederationHandler(BaseHandler):
# so we can rely on it now. # so we can rely on it now.
# #
await self.store.upsert_room_on_join( await self.store.upsert_room_on_join(
room_id=room_id, room_version=room_version_obj, room_id=room_id,
room_version=room_version_obj,
) )
max_stream_id = await self._persist_auth_tree( max_stream_id = await self._persist_auth_tree(
@ -1483,7 +1498,8 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room: if not is_in_room:
logger.info( logger.info(
"Got /make_join request for room %s we are no longer in", room_id, "Got /make_join request for room %s we are no longer in",
room_id,
) )
raise NotFoundError("Not an active room on this server") raise NotFoundError("Not an active room on this server")
@ -1776,8 +1792,7 @@ class FederationHandler(BaseHandler):
return None return None
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event. """Returns the state at the event. i.e. not including said event."""
"""
event = await self.store.get_event(event_id, check_room_id=room_id) event = await self.store.get_event(event_id, check_room_id=room_id)
@ -1803,8 +1818,7 @@ class FederationHandler(BaseHandler):
return [] return []
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event. """Returns the state at the event. i.e. not including said event."""
"""
event = await self.store.get_event(event_id, check_room_id=room_id) event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@ -2010,7 +2024,11 @@ class FederationHandler(BaseHandler):
for e_id in missing_auth_events: for e_id in missing_auth_events:
m_ev = await self.federation_client.get_pdu( m_ev = await self.federation_client.get_pdu(
[origin], e_id, room_version=room_version, outlier=True, timeout=10000, [origin],
e_id,
room_version=room_version,
outlier=True,
timeout=10000,
) )
if m_ev and m_ev.event_id == e_id: if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev event_map[e_id] = m_ev
@ -2160,7 +2178,9 @@ class FederationHandler(BaseHandler):
) )
logger.debug( logger.debug(
"Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, "Doing soft-fail check for %s: state %s",
event.event_id,
current_state_ids,
) )
# Now check if event pass auth against said current state # Now check if event pass auth against said current state

View File

@ -146,8 +146,7 @@ class GroupsLocalWorkerHandler:
async def get_users_in_group( async def get_users_in_group(
self, group_id: str, requester_user_id: str self, group_id: str, requester_user_id: str
) -> JsonDict: ) -> JsonDict:
"""Get users in a group """Get users in a group"""
"""
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
return await self.groups_server_handler.get_users_in_group( return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id group_id, requester_user_id
@ -283,8 +282,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def create_group( async def create_group(
self, group_id: str, user_id: str, content: JsonDict self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict: ) -> JsonDict:
"""Create a group """Create a group"""
"""
logger.info("Asking to create group with ID: %r", group_id) logger.info("Asking to create group with ID: %r", group_id)
@ -314,8 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def join_group( async def join_group(
self, group_id: str, user_id: str, content: JsonDict self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict: ) -> JsonDict:
"""Request to join a group """Request to join a group"""
"""
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
await self.groups_server_handler.join_group(group_id, user_id, content) await self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None local_attestation = None
@ -361,8 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def accept_invite( async def accept_invite(
self, group_id: str, user_id: str, content: JsonDict self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict: ) -> JsonDict:
"""Accept an invite to a group """Accept an invite to a group"""
"""
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
await self.groups_server_handler.accept_invite(group_id, user_id, content) await self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None local_attestation = None
@ -408,8 +404,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def invite( async def invite(
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
) -> JsonDict: ) -> JsonDict:
"""Invite a user to a group """Invite a user to a group"""
"""
content = {"requester_user_id": requester_user_id, "config": config} content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
res = await self.groups_server_handler.invite_to_group( res = await self.groups_server_handler.invite_to_group(
@ -434,8 +429,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def on_invite( async def on_invite(
self, group_id: str, user_id: str, content: JsonDict self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict: ) -> JsonDict:
"""One of our users were invited to a group """One of our users were invited to a group"""
"""
# TODO: Support auto join and rejection # TODO: Support auto join and rejection
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):
@ -466,8 +460,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def remove_user_from_group( async def remove_user_from_group(
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict: ) -> JsonDict:
"""Remove a user from a group """Remove a user from a group"""
"""
if user_id == requester_user_id: if user_id == requester_user_id:
token = await self.store.register_user_group_membership( token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave" group_id, user_id, membership="leave"
@ -501,8 +494,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def user_removed_from_group( async def user_removed_from_group(
self, group_id: str, user_id: str, content: JsonDict self, group_id: str, user_id: str, content: JsonDict
) -> None: ) -> None:
"""One of our users was removed/kicked from a group """One of our users was removed/kicked from a group"""
"""
# TODO: Check if user in group # TODO: Check if user in group
token = await self.store.register_user_group_membership( token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave" group_id, user_id, membership="leave"

View File

@ -72,7 +72,10 @@ class IdentityHandler(BaseHandler):
) )
def ratelimit_request_token_requests( def ratelimit_request_token_requests(
self, request: SynapseRequest, medium: str, address: str, self,
request: SynapseRequest,
medium: str,
address: str,
): ):
"""Used to ratelimit requests to `/requestToken` by IP and address. """Used to ratelimit requests to `/requestToken` by IP and address.

View File

@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler):
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
receipt = await self.store.get_linearized_receipts_for_rooms( receipt = await self.store.get_linearized_receipts_for_rooms(
joined_rooms, to_key=int(now_token.receipt_key), joined_rooms,
to_key=int(now_token.receipt_key),
) )
tags_by_room = await self.store.get_tags_for_user(user_id) tags_by_room = await self.store.get_tags_for_user(user_id)
@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id self.state_handler.get_current_state, event.room_id
) )
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(None, event.stream_ordering,) room_end_token = RoomStreamToken(
None,
event.stream_ordering,
)
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id] self.state_store.get_state_for_events, [event.event_id]
) )
@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler):
membership, membership,
member_event_id, member_event_id,
) = await self.auth.check_user_in_room_or_world_readable( ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True, room_id,
user_id,
allow_departed_users=True,
) )
is_peeking = member_event_id is None is_peeking = member_event_id is None

View File

@ -65,8 +65,7 @@ logger = logging.getLogger(__name__)
class MessageHandler: class MessageHandler:
"""Contains some read only APIs to get state about a room """Contains some read only APIs to get state about a room"""
"""
def __init__(self, hs): def __init__(self, hs):
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -88,7 +87,11 @@ class MessageHandler:
) )
async def get_room_data( async def get_room_data(
self, user_id: str, room_id: str, event_type: str, state_key: str, self,
user_id: str,
room_id: str,
event_type: str,
state_key: str,
) -> dict: ) -> dict:
"""Get data from a room. """Get data from a room.
@ -174,7 +177,10 @@ class MessageHandler:
raise NotFoundError("Can't find event for token %s" % (at_token,)) raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = await filter_events_for_client( visible_events = await filter_events_for_client(
self.storage, user_id, last_events, filter_send_to_client=False, self.storage,
user_id,
last_events,
filter_send_to_client=False,
) )
event = last_events[0] event = last_events[0]
@ -793,9 +799,10 @@ class EventCreationHandler:
""" """
if prev_event_ids is not None: if prev_event_ids is not None:
assert len(prev_event_ids) <= 10, ( assert (
"Attempting to create an event with %i prev_events" len(prev_event_ids) <= 10
% (len(prev_event_ids),) ), "Attempting to create an event with %i prev_events" % (
len(prev_event_ids),
) )
else: else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
@ -821,7 +828,8 @@ class EventCreationHandler:
) )
if not third_party_result: if not third_party_result:
logger.info( logger.info(
"Event %s forbidden by third-party rules", event, "Event %s forbidden by third-party rules",
event,
) )
raise SynapseError( raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN 403, "This event is not allowed in this context", Codes.FORBIDDEN
@ -1316,7 +1324,11 @@ class EventCreationHandler:
# Since this is a dummy-event it is OK if it is sent by a # Since this is a dummy-event it is OK if it is sent by a
# shadow-banned user. # shadow-banned user.
await self.handle_new_client_event( await self.handle_new_client_event(
requester, event, context, ratelimit=False, ignore_shadow_ban=True, requester,
event,
context,
ratelimit=False,
ignore_shadow_ban=True,
) )
return True return True
except AuthError: except AuthError:

View File

@ -73,8 +73,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcHandler: class OidcHandler:
"""Handles requests related to the OpenID Connect login flow. """Handles requests related to the OpenID Connect login flow."""
"""
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
@ -216,8 +215,7 @@ class OidcHandler:
class OidcError(Exception): class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint """Used to catch errors when calling the token_endpoint"""
"""
def __init__(self, error, error_description=None): def __init__(self, error, error_description=None):
self.error = error self.error = error
@ -252,7 +250,9 @@ class OidcProvider:
self._scopes = provider.scopes self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth( self._client_auth = ClientAuth(
provider.client_id, provider.client_secret, provider.client_auth_method, provider.client_id,
provider.client_secret,
provider.client_auth_method,
) # type: ClientAuth ) # type: ClientAuth
self._client_auth_method = provider.client_auth_method self._client_auth_method = provider.client_auth_method
@ -509,7 +509,10 @@ class OidcProvider:
# We're not using the SimpleHttpClient util methods as we don't want to # We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code and we do the body encoding ourself. # check the HTTP status code and we do the body encoding ourself.
response = await self._http_client.request( response = await self._http_client.request(
method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, method="POST",
uri=uri,
data=body.encode("utf-8"),
headers=headers,
) )
# This is used in multiple error messages below # This is used in multiple error messages below
@ -966,7 +969,9 @@ class OidcSessionTokenGenerator:
A signed macaroon token with the session information. A signed macaroon token with the session information.
""" """
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key, location=self._server_name,
identifier="key",
key=self._macaroon_secret_key,
) )
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session") macaroon.add_first_party_caveat("type = session")

View File

@ -197,7 +197,8 @@ class PaginationHandler:
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
r = await self.store.get_room_event_before_stream_ordering( r = await self.store.get_room_event_before_stream_ordering(
room_id, stream_ordering, room_id,
stream_ordering,
) )
if not r: if not r:
logger.warning( logger.warning(
@ -223,7 +224,12 @@ class PaginationHandler:
# the background so that it's not blocking any other operation apart from # the background so that it's not blocking any other operation apart from
# other purges in the same room. # other purges in the same room.
run_as_background_process( run_as_background_process(
"_purge_history", self._purge_history, purge_id, room_id, token, True, "_purge_history",
self._purge_history,
purge_id,
room_id,
token,
True,
) )
def start_purge_history( def start_purge_history(
@ -389,7 +395,9 @@ class PaginationHandler:
) )
await self.hs.get_federation_handler().maybe_backfill( await self.hs.get_federation_handler().maybe_backfill(
room_id, curr_topo, limit=pagin_config.limit, room_id,
curr_topo,
limit=pagin_config.limit,
) )
to_room_key = None to_room_key = None

View File

@ -635,8 +635,7 @@ class PresenceHandler(BasePresenceHandler):
self.external_process_last_updated_ms.pop(process_id, None) self.external_process_last_updated_ms.pop(process_id, None)
async def current_state_for_user(self, user_id): async def current_state_for_user(self, user_id):
"""Get the current presence state for a user. """Get the current presence state for a user."""
"""
res = await self.current_state_for_users([user_id]) res = await self.current_state_for_users([user_id])
return res[user_id] return res[user_id]
@ -678,8 +677,7 @@ class PresenceHandler(BasePresenceHandler):
self.federation.send_presence(states) self.federation.send_presence(states)
async def incoming_presence(self, origin, content): async def incoming_presence(self, origin, content):
"""Called when we receive a `m.presence` EDU from a remote server. """Called when we receive a `m.presence` EDU from a remote server."""
"""
if not self._presence_enabled: if not self._presence_enabled:
return return
@ -729,8 +727,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states(updates) await self._update_states(updates)
async def set_state(self, target_user, state, ignore_status_msg=False): async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user. """Set the presence state of the user."""
"""
status_msg = state.get("status_msg", None) status_msg = state.get("status_msg", None)
presence = state["presence"] presence = state["presence"]
@ -758,8 +755,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states([prev_state.copy_and_replace(**new_fields)]) await self._update_states([prev_state.copy_and_replace(**new_fields)])
async def is_visible(self, observed_user, observer_user): async def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence. """Returns whether a user can see another user's presence."""
"""
observer_room_ids = await self.store.get_rooms_for_user( observer_room_ids = await self.store.get_rooms_for_user(
observer_user.to_string() observer_user.to_string()
) )
@ -953,8 +949,7 @@ class PresenceHandler(BasePresenceHandler):
def should_notify(old_state, new_state): def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties. """Decides if a presence state change should be sent to interested parties."""
"""
if old_state == new_state: if old_state == new_state:
return False return False

View File

@ -207,7 +207,8 @@ class ProfileHandler(BaseHandler):
# This must be done by the target user himself. # This must be done by the target user himself.
if by_admin: if by_admin:
requester = create_requester( requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity, target_user,
authenticated_entity=requester.authenticated_entity,
) )
await self.store.set_profile_displayname( await self.store.set_profile_displayname(

View File

@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler):
) )
else: else:
hs.get_federation_registry().register_instances_for_edu( hs.get_federation_registry().register_instances_for_edu(
"m.receipt", hs.config.worker.writers.receipts, "m.receipt",
hs.config.worker.writers.receipts,
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS. """Called when we receive an EDU of type m.receipt from a remote HS."""
"""
receipts = [] receipts = []
for room_id, room_values in content.items(): for room_id, room_values in content.items():
for receipt_type, users in room_values.items(): for receipt_type, users in room_values.items():
@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler):
await self._handle_new_receipts(receipts) await self._handle_new_receipts(receipts)
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier. """Takes a list of receipts, stores them and informs the notifier."""
"""
min_batch_id = None # type: Optional[int] min_batch_id = None # type: Optional[int]
max_batch_id = None # type: Optional[int] max_batch_id = None # type: Optional[int]

View File

@ -62,8 +62,8 @@ class RegistrationHandler(BaseHandler):
self._register_device_client = RegisterDeviceReplicationServlet.make_client( self._register_device_client = RegisterDeviceReplicationServlet.make_client(
hs hs
) )
self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( self._post_registration_client = (
hs ReplicationPostRegisterActionsServlet.make_client(hs)
) )
else: else:
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@ -189,12 +189,15 @@ class RegistrationHandler(BaseHandler):
self.check_registration_ratelimit(address) self.check_registration_ratelimit(address)
result = await self.spam_checker.check_registration_for_spam( result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [], threepid,
localpart,
user_agent_ips or [],
) )
if result == RegistrationBehaviour.DENY: if result == RegistrationBehaviour.DENY:
logger.info( logger.info(
"Blocked registration of %r", localpart, "Blocked registration of %r",
localpart,
) )
# We return a 429 to make it not obvious that they've been # We return a 429 to make it not obvious that they've been
# denied. # denied.
@ -203,7 +206,8 @@ class RegistrationHandler(BaseHandler):
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
if shadow_banned: if shadow_banned:
logger.info( logger.info(
"Shadow banning registration of %r", localpart, "Shadow banning registration of %r",
localpart,
) )
# do not check_auth_blocking if the call is coming through the Admin API # do not check_auth_blocking if the call is coming through the Admin API
@ -369,7 +373,9 @@ class RegistrationHandler(BaseHandler):
config["room_alias_name"] = room_alias.localpart config["room_alias_name"] = room_alias.localpart
info, _ = await room_creation_handler.create_room( info, _ = await room_creation_handler.create_room(
fake_requester, config=config, ratelimit=False, fake_requester,
config=config,
ratelimit=False,
) )
# If the room does not require an invite, but another user # If the room does not require an invite, but another user
@ -753,7 +759,10 @@ class RegistrationHandler(BaseHandler):
return return
await self._auth_handler.add_threepid( await self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"], user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
) )
# And we add an email pusher for them by default, but only # And we add an email pusher for them by default, but only
@ -805,5 +814,8 @@ class RegistrationHandler(BaseHandler):
raise raise
await self._auth_handler.add_threepid( await self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"], user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
) )

View File

@ -198,7 +198,9 @@ class RoomCreationHandler(BaseHandler):
if r is None: if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,)) raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = await self._generate_room_id( new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version, creator_id=user_id,
is_public=r["is_public"],
room_version=new_version,
) )
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
@ -236,7 +238,9 @@ class RoomCreationHandler(BaseHandler):
# now send the tombstone # now send the tombstone
await self.event_creation_handler.handle_new_client_event( await self.event_creation_handler.handle_new_client_event(
requester=requester, event=tombstone_event, context=tombstone_context, requester=requester,
event=tombstone_event,
context=tombstone_context,
) )
old_room_state = await tombstone_context.get_current_state_ids() old_room_state = await tombstone_context.get_current_state_ids()
@ -257,7 +261,10 @@ class RoomCreationHandler(BaseHandler):
# finally, shut down the PLs in the old room, and update them in the new # finally, shut down the PLs in the old room, and update them in the new
# room. # room.
await self._update_upgraded_room_pls( await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state, requester,
old_room_id,
new_room_id,
old_room_state,
) )
return new_room_id return new_room_id
@ -691,7 +698,9 @@ class RoomCreationHandler(BaseHandler):
is_public = visibility == "public" is_public = visibility == "public"
room_id = await self._generate_room_id( room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version, creator_id=user_id,
is_public=is_public,
room_version=room_version,
) )
# Check whether this visibility value is blocked by a third party module # Check whether this visibility value is blocked by a third party module
@ -884,7 +893,10 @@ class RoomCreationHandler(BaseHandler):
_, _,
last_stream_id, last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event( ) = await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False, ignore_shadow_ban=True, creator,
event,
ratelimit=False,
ignore_shadow_ban=True,
) )
return last_stream_id return last_stream_id
@ -984,7 +996,10 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id return last_sent_stream_id
async def _generate_room_id( async def _generate_room_id(
self, creator_id: str, is_public: bool, room_version: RoomVersion, self,
creator_id: str,
is_public: bool,
room_version: RoomVersion,
): ):
# autogen room IDs and try to create it. We may clash, so just # autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually. # try a few times till one goes through, giving up eventually.

View File

@ -191,7 +191,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# do it up front for efficiency.) # do it up front for efficiency.)
if txn_id and requester.access_token_id: if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id( existing_event_id = await self.store.get_event_id_from_transaction_id(
room_id, requester.user.to_string(), requester.access_token_id, txn_id, room_id,
requester.user.to_string(),
requester.access_token_id,
txn_id,
) )
if existing_event_id: if existing_event_id:
event_pos = await self.store.get_position_for_event(existing_event_id) event_pos = await self.store.get_position_for_event(existing_event_id)
@ -238,7 +241,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
) )
result_event = await self.event_creation_handler.handle_new_client_event( result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit, requester,
event,
context,
extra_users=[target],
ratelimit=ratelimit,
) )
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
@ -583,7 +590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# send the rejection to the inviter's HS (with fallback to # send the rejection to the inviter's HS (with fallback to
# local event) # local event)
return await self.remote_reject_invite( return await self.remote_reject_invite(
invite.event_id, txn_id, requester, content, invite.event_id,
txn_id,
requester,
content,
) )
# the inviter was on our server, but has now left. Carry on # the inviter was on our server, but has now left. Carry on
@ -1056,8 +1066,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
user: UserID, user: UserID,
content: dict, content: dict,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join """Implements RoomMemberHandler._remote_join"""
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it # filter ourselves out of remote_room_hosts: do_invite_join ignores it
# and if it is the only entry we'd like to return a 404 rather than a # and if it is the only entry we'd like to return a 404 rather than a
# 500. # 500.
@ -1211,7 +1220,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
event.internal_metadata.out_of_band_membership = True event.internal_metadata.out_of_band_membership = True
result_event = await self.event_creation_handler.handle_new_client_event( result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)], requester,
event,
context,
extra_users=[UserID.from_string(target_user)],
) )
# we know it was persisted, so must have a stream ordering # we know it was persisted, so must have a stream ordering
assert result_event.internal_metadata.stream_ordering assert result_event.internal_metadata.stream_ordering
@ -1219,8 +1231,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return result_event.event_id, result_event.internal_metadata.stream_ordering return result_event.event_id, result_event.internal_metadata.stream_ordering
async def _user_left_room(self, target: UserID, room_id: str) -> None: async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room """Implements RoomMemberHandler._user_left_room"""
"""
user_left_room(self.distributor, target, room_id) user_left_room(self.distributor, target, room_id)
async def forget(self, user: UserID, room_id: str) -> None: async def forget(self, user: UserID, room_id: str) -> None:

View File

@ -44,8 +44,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
user: UserID, user: UserID,
content: dict, content: dict,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join """Implements RoomMemberHandler._remote_join"""
"""
if len(remote_room_hosts) == 0: if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
@ -80,8 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret["event_id"], ret["stream_id"] return ret["event_id"], ret["stream_id"]
async def _user_left_room(self, target: UserID, room_id: str) -> None: async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room """Implements RoomMemberHandler._user_left_room"""
"""
await self._notify_change_client( await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left" user_id=target.to_string(), room_id=room_id, change="left"
) )

View File

@ -121,7 +121,8 @@ class SamlHandler(BaseHandler):
now = self.clock.time_msec() now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData( self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id, creation_time=now,
ui_auth_session_id=ui_auth_session_id,
) )
for key, value in info["headers"]: for key, value in info["headers"]:
@ -450,7 +451,8 @@ class DefaultSamlMappingProvider:
mxid_source = saml_response.ava[self._mxid_source_attribute][0] mxid_source = saml_response.ava[self._mxid_source_attribute][0]
except KeyError: except KeyError:
logger.warning( logger.warning(
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, "SAML2 response lacks a '%s' attestation",
self._mxid_source_attribute,
) )
raise SynapseError( raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,) 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)

View File

@ -327,7 +327,8 @@ class SsoHandler:
# Check if we already have a mapping for this user. # Check if we already have a mapping for this user.
previously_registered_user_id = await self._store.get_user_by_external_id( previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id, auth_provider_id,
remote_user_id,
) )
# A match was found, return the user ID. # A match was found, return the user ID.
@ -416,7 +417,8 @@ class SsoHandler:
with await self._mapping_lock.queue(auth_provider_id): with await self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user # first of all, check if we already have a mapping for this user
user_id = await self.get_sso_user_by_remote_user_id( user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id, auth_provider_id,
remote_user_id,
) )
# Check for grandfathering of users. # Check for grandfathering of users.
@ -461,7 +463,8 @@ class SsoHandler:
) )
async def _call_attribute_mapper( async def _call_attribute_mapper(
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], self,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
) -> UserAttributes: ) -> UserAttributes:
"""Call the attribute mapper function in a loop, until we get a unique userid""" """Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES): for i in range(self._MAP_USERNAME_RETRIES):
@ -632,7 +635,8 @@ class SsoHandler:
""" """
user_id = await self.get_sso_user_by_remote_user_id( user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id, auth_provider_id,
remote_user_id,
) )
user_id_to_verify = await self._auth_handler.get_session_data( user_id_to_verify = await self._auth_handler.get_session_data(
@ -671,7 +675,8 @@ class SsoHandler:
# render an error page. # render an error page.
html = self._bad_user_template.render( html = self._bad_user_template.render(
server_name=self._server_name, user_id_to_verify=user_id_to_verify, server_name=self._server_name,
user_id_to_verify=user_id_to_verify,
) )
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
@ -695,7 +700,9 @@ class SsoHandler:
raise SynapseError(400, "unknown session") raise SynapseError(400, "unknown session")
async def check_username_availability( async def check_username_availability(
self, localpart: str, session_id: str, self,
localpart: str,
session_id: str,
) -> bool: ) -> bool:
"""Handle an "is username available" callback check """Handle an "is username available" callback check
@ -833,7 +840,8 @@ class SsoHandler:
) )
attributes = UserAttributes( attributes = UserAttributes(
localpart=session.chosen_localpart, emails=session.emails_to_use, localpart=session.chosen_localpart,
emails=session.emails_to_use,
) )
if session.use_display_name: if session.use_display_name:

View File

@ -63,8 +63,7 @@ class StatsHandler:
self.clock.call_later(0, self.notify_new_event) self.clock.call_later(0, self.notify_new_event)
def notify_new_event(self) -> None: def notify_new_event(self) -> None:
"""Called when there may be more deltas to process """Called when there may be more deltas to process"""
"""
if not self.stats_enabled or self._is_processing: if not self.stats_enabled or self._is_processing:
return return

View File

@ -339,8 +339,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
full_state: bool = False, full_state: bool = False,
) -> SyncResult: ) -> SyncResult:
"""Get the sync for client needed to match what the server has now. """Get the sync for client needed to match what the server has now."""
"""
return await self.generate_sync_result(sync_config, since_token, full_state) return await self.generate_sync_result(sync_config, since_token, full_state)
async def push_rules_for_user(self, user: UserID) -> JsonDict: async def push_rules_for_user(self, user: UserID) -> JsonDict:
@ -820,9 +819,11 @@ class SyncHandler:
) )
elif batch.limited: elif batch.limited:
if batch: if batch:
state_at_timeline_start = await self.state_store.get_state_ids_for_event( state_at_timeline_start = (
await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id, state_filter=state_filter
) )
)
else: else:
# We can get here if the user has ignored the senders of all # We can get here if the user has ignored the senders of all
# the recent events. # the recent events.
@ -955,8 +956,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
full_state: bool = False, full_state: bool = False,
) -> SyncResult: ) -> SyncResult:
"""Generates a sync result. """Generates a sync result."""
"""
# NB: The now_token gets changed by some of the generate_sync_* methods, # NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability # this is due to some of the underlying streams not supporting the ability
# to query up to a given point. # to query up to a given point.
@ -1030,8 +1030,8 @@ class SyncHandler:
one_time_key_counts = await self.store.count_e2e_one_time_keys( one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id user_id, device_id
) )
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( unused_fallback_key_types = (
user_id, device_id await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
) )
logger.debug("Fetching group data") logger.debug("Fetching group data")
@ -1176,9 +1176,11 @@ class SyncHandler:
# weren't in the previous sync *or* they left and rejoined. # weren't in the previous sync *or* they left and rejoined.
users_that_have_changed.update(newly_joined_or_invited_users) users_that_have_changed.update(newly_joined_or_invited_users)
user_signatures_changed = await self.store.get_users_whose_signatures_changed( user_signatures_changed = (
await self.store.get_users_whose_signatures_changed(
user_id, since_token.device_list_key user_id, since_token.device_list_key
) )
)
users_that_have_changed.update(user_signatures_changed) users_that_have_changed.update(user_signatures_changed)
# Now find users that we no longer track # Now find users that we no longer track
@ -1393,9 +1395,11 @@ class SyncHandler:
logger.debug("no-oping sync") logger.debug("no-oping sync")
return set(), set(), set(), set() return set(), set(), set(), set()
ignored_account_data = await self.store.get_global_account_data_by_type_for_user( ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
) )
)
# If there is ignored users account data and it matches the proper type, # If there is ignored users account data and it matches the proper type,
# then use it. # then use it.
@ -1499,8 +1503,7 @@ class SyncHandler:
async def _get_rooms_changed( async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges: ) -> _RoomChanges:
"""Gets the the changes that have happened since the last sync. """Gets the the changes that have happened since the last sync."""
"""
user_id = sync_result_builder.sync_config.user.to_string() user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token now_token = sync_result_builder.now_token

View File

@ -61,7 +61,8 @@ class FollowerTypingHandler:
if hs.config.worker.writers.typing != hs.get_instance_name(): if hs.config.worker.writers.typing != hs.get_instance_name():
hs.get_federation_registry().register_instance_for_edu( hs.get_federation_registry().register_instance_for_edu(
"m.typing", hs.config.worker.writers.typing, "m.typing",
hs.config.worker.writers.typing,
) )
# map room IDs to serial numbers # map room IDs to serial numbers
@ -76,8 +77,7 @@ class FollowerTypingHandler:
self.clock.looping_call(self._handle_timeouts, 5000) self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self) -> None: def _reset(self) -> None:
"""Reset the typing handler's data caches. """Reset the typing handler's data caches."""
"""
# map room IDs to serial numbers # map room IDs to serial numbers
self._room_serials = {} self._room_serials = {}
# map room IDs to sets of users currently typing # map room IDs to sets of users currently typing
@ -149,8 +149,7 @@ class FollowerTypingHandler:
def process_replication_rows( def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow] self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None: ) -> None:
"""Should be called whenever we receive updates for typing stream. """Should be called whenever we receive updates for typing stream."""
"""
if self._latest_room_serial > token: if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just # The master has gone backwards. To prevent inconsistent data, just

View File

@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler):
return results return results
def notify_new_event(self) -> None: def notify_new_event(self) -> None:
"""Called when there may be more deltas to process """Called when there may be more deltas to process"""
"""
if not self.update_user_directory: if not self.update_user_directory:
return return
@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler):
) )
async def handle_user_deactivated(self, user_id: str) -> None: async def handle_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated """Called when a user ID is deactivated"""
"""
# FIXME(#3714): We should probably do this in the same worker as all # FIXME(#3714): We should probably do this in the same worker as all
# the other changes. # the other changes.
await self.store.remove_from_user_dir(user_id) await self.store.remove_from_user_dir(user_id)
@ -172,8 +170,7 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos) await self.store.update_user_directory_stream_pos(max_pos)
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
"""Called with the state deltas to process """Called with the state deltas to process"""
"""
for delta in deltas: for delta in deltas:
typ = delta["type"] typ = delta["type"]
state_key = delta["state_key"] state_key = delta["state_key"]

View File

@ -54,8 +54,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
def get_request_user_agent(request: IRequest, default: str = "") -> str: def get_request_user_agent(request: IRequest, default: str = "") -> str:
"""Return the last User-Agent header, or the given default. """Return the last User-Agent header, or the given default."""
"""
# There could be raw utf-8 bytes in the User-Agent header. # There could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically # N.B. if you don't do this, the logger explodes cryptically

View File

@ -398,7 +398,8 @@ class SimpleHttpClient:
body_producer = None body_producer = None
if data is not None: if data is not None:
body_producer = QuieterFileBodyProducer( body_producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator, BytesIO(data),
cooperator=self._cooperator,
) )
request_deferred = treq.request( request_deferred = treq.request(
@ -413,7 +414,9 @@ class SimpleHttpClient:
# we use our own timeout mechanism rather than treq's as a workaround # we use our own timeout mechanism rather than treq's as a workaround
# for https://twistedmatrix.com/trac/ticket/9534. # for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred( request_deferred = timeout_deferred(
request_deferred, 60, self.hs.get_reactor(), request_deferred,
60,
self.hs.get_reactor(),
) )
# turn timeouts into RequestTimedOutErrors # turn timeouts into RequestTimedOutErrors

View File

@ -195,8 +195,7 @@ class MatrixFederationAgent:
@implementer(IAgentEndpointFactory) @implementer(IAgentEndpointFactory)
class MatrixHostnameEndpointFactory: class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for parsing to an Agent. """Factory for MatrixHostnameEndpoint for parsing to an Agent."""
"""
def __init__( def __init__(
self, self,
@ -261,8 +260,7 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver self._srv_resolver = srv_resolver
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
"""Implements IStreamClientEndpoint interface """Implements IStreamClientEndpoint interface"""
"""
return run_in_background(self._do_connect, protocol_factory) return run_in_background(self._do_connect, protocol_factory)

View File

@ -81,8 +81,7 @@ class WellKnownLookupResult:
class WellKnownResolver: class WellKnownResolver:
"""Handles well-known lookups for matrix servers. """Handles well-known lookups for matrix servers."""
"""
def __init__( def __init__(
self, self,

View File

@ -254,7 +254,8 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP # Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names # blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper( self.agent = BlacklistingAgentWrapper(
self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist, self.agent,
ip_blacklist=hs.config.federation_ip_range_blacklist,
) )
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -799,7 +800,11 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout _sec_timeout = self.default_timeout
body = await _handle_json_response( body = await _handle_json_response(
self.reactor, _sec_timeout, request, response, start_ms, self.reactor,
_sec_timeout,
request,
response,
start_ms,
) )
return body return body
@ -994,7 +999,10 @@ class MatrixFederationHttpClient:
except BodyExceededMaxSize: except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (max_size,) msg = "Requested file is too large > %r bytes" % (max_size,)
logger.warning( logger.warning(
"{%s} [%s] %s", request.txn_id, request.destination, msg, "{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
) )
raise SynapseError(502, msg, Codes.TOO_LARGE) raise SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e: except Exception as e:

View File

@ -213,8 +213,7 @@ class RequestMetrics:
self.update_metrics() self.update_metrics()
def update_metrics(self): def update_metrics(self):
"""Updates the in flight metrics with values from this request. """Updates the in flight metrics with values from this request."""
"""
new_stats = self.start_context.get_resource_usage() new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats diff = new_stats - self._request_stats

View File

@ -76,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients. """Sends a JSON error response to clients."""
"""
if f.check(SynapseError): if f.check(SynapseError):
error_code = f.value.code error_code = f.value.code
@ -106,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass pass
else: else:
respond_with_json( respond_with_json(
request, error_code, error_dict, send_cors=True, request,
error_code,
error_dict,
send_cors=True,
) )
def return_html_error( def return_html_error(
f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template], f: failure.Failure,
request: Request,
error_template: Union[str, jinja2.Template],
) -> None: ) -> None:
"""Sends an HTML error page corresponding to the given failure. """Sends an HTML error page corresponding to the given failure.
@ -189,8 +193,7 @@ ServletCallback = Callable[
class HttpServer(Protocol): class HttpServer(Protocol):
""" Interface for registering callbacks on a HTTP server """Interface for registering callbacks on a HTTP server"""
"""
def register_paths( def register_paths(
self, self,
@ -235,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
self._extract_context = extract_context self._extract_context = extract_context
def render(self, request): def render(self, request):
""" This gets called by twisted every time someone sends us a request. """This gets called by twisted every time someone sends us a request."""
"""
defer.ensureDeferred(self._async_render_wrapper(request)) defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET return NOT_DONE_YET
@ -287,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def _send_response( def _send_response(
self, request: SynapseRequest, code: int, response_object: Any, self,
request: SynapseRequest,
code: int,
response_object: Any,
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def _send_error_response( def _send_error_response(
self, f: failure.Failure, request: SynapseRequest, self,
f: failure.Failure,
request: SynapseRequest,
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
@ -308,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource):
self.canonical_json = canonical_json self.canonical_json = canonical_json
def _send_response( def _send_response(
self, request: Request, code: int, response_object: Any, self,
request: Request,
code: int,
response_object: Any,
): ):
"""Implements _AsyncResource._send_response """Implements _AsyncResource._send_response"""
"""
# TODO: Only enable CORS for the requests that need it. # TODO: Only enable CORS for the requests that need it.
respond_with_json( respond_with_json(
request, request,
@ -322,10 +331,11 @@ class DirectServeJsonResource(_AsyncResource):
) )
def _send_error_response( def _send_error_response(
self, f: failure.Failure, request: SynapseRequest, self,
f: failure.Failure,
request: SynapseRequest,
) -> None: ) -> None:
"""Implements _AsyncResource._send_error_response """Implements _AsyncResource._send_error_response"""
"""
return_json_error(f, request) return_json_error(f, request)
@ -443,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource):
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response( def _send_response(
self, request: SynapseRequest, code: int, response_object: Any, self,
request: SynapseRequest,
code: int,
response_object: Any,
): ):
"""Implements _AsyncResource._send_response """Implements _AsyncResource._send_response"""
"""
# We expect to get bytes for us to write # We expect to get bytes for us to write
assert isinstance(response_object, bytes) assert isinstance(response_object, bytes)
html_bytes = response_object html_bytes = response_object
@ -454,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource):
respond_with_html_bytes(request, 200, html_bytes) respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response( def _send_error_response(
self, f: failure.Failure, request: SynapseRequest, self,
f: failure.Failure,
request: SynapseRequest,
) -> None: ) -> None:
"""Implements _AsyncResource._send_error_response """Implements _AsyncResource._send_error_response"""
"""
return_html_error(f, request, self.ERROR_TEMPLATE) return_html_error(f, request, self.ERROR_TEMPLATE)
@ -534,7 +547,9 @@ class _ByteProducer:
min_chunk_size = 1024 min_chunk_size = 1024
def __init__( def __init__(
self, request: Request, iterator: Iterator[bytes], self,
request: Request,
iterator: Iterator[bytes],
): ):
self._request = request self._request = request
self._iterator = iterator self._iterator = iterator
@ -654,7 +669,10 @@ def respond_with_json(
def respond_with_json_bytes( def respond_with_json_bytes(
request: Request, code: int, json_bytes: bytes, send_cors: bool = False, request: Request,
code: int,
json_bytes: bytes,
send_cors: bool = False,
): ):
"""Sends encoded JSON in response to the given request. """Sends encoded JSON in response to the given request.

View File

@ -249,8 +249,7 @@ class SynapseRequest(Request):
) )
def _finished_processing(self): def _finished_processing(self):
"""Log the completion of this request and update the metrics """Log the completion of this request and update the metrics"""
"""
assert self.logcontext is not None assert self.logcontext is not None
usage = self.logcontext.get_resource_usage() usage = self.logcontext.get_resource_usage()
@ -276,7 +275,8 @@ class SynapseRequest(Request):
# authenticated (e.g. and admin is puppetting a user) then we log both. # authenticated (e.g. and admin is puppetting a user) then we log both.
if self.requester.user.to_string() != authenticated_entity: if self.requester.user.to_string() != authenticated_entity:
authenticated_entity = "{},{}".format( authenticated_entity = "{},{}".format(
authenticated_entity, self.requester.user.to_string(), authenticated_entity,
self.requester.user.to_string(),
) )
elif self.requester is not None: elif self.requester is not None:
# This shouldn't happen, but we log it so we don't lose information # This shouldn't happen, but we log it so we don't lose information
@ -322,8 +322,7 @@ class SynapseRequest(Request):
logger.warning("Failed to stop metrics: %r", e) logger.warning("Failed to stop metrics: %r", e)
def _should_log_request(self) -> bool: def _should_log_request(self) -> bool:
"""Whether we should log at INFO that we processed the request. """Whether we should log at INFO that we processed the request."""
"""
if self.path == b"/health": if self.path == b"/health":
return False return False

View File

@ -174,7 +174,9 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it. # Make a new producer and start it.
self._producer = LogProducer( self._producer = LogProducer(
buffer=self._buffer, transport=result.transport, format=self.format, buffer=self._buffer,
transport=result.transport,
format=self.format,
) )
result.transport.registerProducer(self._producer, True) result.transport.registerProducer(self._producer, True)
self._producer.resumeProducing() self._producer.resumeProducing()

View File

@ -60,7 +60,10 @@ def parse_drain_configs(
) )
# Either use the default formatter or the tersejson one. # Either use the default formatter or the tersejson one.
if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,): if logging_type in (
DrainType.CONSOLE_JSON,
DrainType.FILE_JSON,
):
formatter = "json" # type: Optional[str] formatter = "json" # type: Optional[str]
elif logging_type in ( elif logging_type in (
DrainType.CONSOLE_JSON_TERSE, DrainType.CONSOLE_JSON_TERSE,
@ -131,7 +134,9 @@ def parse_drain_configs(
) )
def setup_structured_logging(log_config: dict,) -> dict: def setup_structured_logging(
log_config: dict,
) -> dict:
""" """
Convert a legacy structured logging configuration (from Synapse < v1.23.0) Convert a legacy structured logging configuration (from Synapse < v1.23.0)
to one compatible with the new standard library handlers. to one compatible with the new standard library handlers.

View File

@ -338,7 +338,10 @@ class LoggingContext:
if self.previous_context != old_context: if self.previous_context != old_context:
logcontext_error( logcontext_error(
"Expected previous context %r, found %r" "Expected previous context %r, found %r"
% (self.previous_context, old_context,) % (
self.previous_context,
old_context,
)
) )
return self return self
@ -585,7 +588,10 @@ class PreserveLoggingContext:
else: else:
logcontext_error( logcontext_error(
"Expected logging context %s but found %s" "Expected logging context %s but found %s"
% (self._new_context, context,) % (
self._new_context,
context,
)
) )

View File

@ -238,8 +238,7 @@ try:
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class _WrappedRustReporter: class _WrappedRustReporter:
"""Wrap the reporter to ensure `report_span` never throws. """Wrap the reporter to ensure `report_span` never throws."""
"""
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter)) _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs):
def init_tracer(hs: "HomeServer"): def init_tracer(hs: "HomeServer"):
"""Set the whitelists and initialise the JaegerClient tracer """Set the whitelists and initialise the JaegerClient tracer"""
"""
global opentracing global opentracing
if not hs.config.opentracer_enabled: if not hs.config.opentracer_enabled:
# We don't have a tracer # We don't have a tracer

View File

@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args):
def log_function(f): def log_function(f):
""" Function decorator that logs every call to that function. """Function decorator that logs every call to that function."""
"""
func_name = f.__name__ func_name = f.__name__
@wraps(f) @wraps(f)

View File

@ -155,8 +155,7 @@ class InFlightGauge:
self._registrations.setdefault(key, set()).add(callback) self._registrations.setdefault(key, set()).add(callback)
def unregister(self, key, callback): def unregister(self, key, callback):
"""Registers that we've exited a block with labels `key`. """Registers that we've exited a block with labels `key`."""
"""
with self._lock: with self._lock:
self._registrations.setdefault(key, set()).discard(callback) self._registrations.setdefault(key, set()).discard(callback)
@ -402,7 +401,9 @@ class PyPyGCStats:
# Total time spent in GC: 0.073 # s.total_gc_time # Total time spent in GC: 0.073 # s.total_gc_time
pypy_gc_time = CounterMetricFamily( pypy_gc_time = CounterMetricFamily(
"pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[], "pypy_gc_time_seconds_total",
"Total time spent in PyPy GC",
labels=[],
) )
pypy_gc_time.add_metric([], s.total_gc_time / 1000) pypy_gc_time.add_metric([], s.total_gc_time / 1000)
yield pypy_gc_time yield pypy_gc_time

View File

@ -208,7 +208,8 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
return await maybe_awaitable(func(*args, **kwargs)) return await maybe_awaitable(func(*args, **kwargs))
except Exception: except Exception:
logger.exception( logger.exception(
"Background process '%s' threw an exception", desc, "Background process '%s' threw an exception",
desc,
) )
finally: finally:
_background_process_in_flight_count.labels(desc).dec() _background_process_in_flight_count.labels(desc).dec()
@ -249,8 +250,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
self._proc = _BackgroundProcess(name, self) self._proc = _BackgroundProcess(name, self)
def start(self, rusage: "Optional[resource._RUsage]"): def start(self, rusage: "Optional[resource._RUsage]"):
"""Log context has started running (again). """Log context has started running (again)."""
"""
super().start(rusage) super().start(rusage)
@ -261,8 +261,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
_background_processes_active_since_last_scrape.add(self._proc) _background_processes_active_since_last_scrape.add(self._proc)
def __exit__(self, type, value, traceback) -> None: def __exit__(self, type, value, traceback) -> None:
"""Log context has finished. """Log context has finished."""
"""
super().__exit__(type, value, traceback) super().__exit__(type, value, traceback)

View File

@ -275,7 +275,9 @@ class ModuleApi:
redirect them directly if whitelisted). redirect them directly if whitelisted).
""" """
self._auth_handler._complete_sso_login( self._auth_handler._complete_sso_login(
registered_user_id, request, client_redirect_url, registered_user_id,
request,
client_redirect_url,
) )
async def complete_sso_login_async( async def complete_sso_login_async(
@ -352,7 +354,10 @@ class ModuleApi:
event, event,
_, _,
) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event( ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
requester, event_dict, ratelimit=False, ignore_shadow_ban=True, requester,
event_dict,
ratelimit=False,
ignore_shadow_ban=True,
) )
return event return event

View File

@ -119,7 +119,10 @@ class _NotifierUserStream:
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify( def notify(
self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int, self,
stream_key: str,
stream_id: Union[int, RoomStreamToken],
time_now_ms: int,
): ):
"""Notify any listeners for this user of a new event from an """Notify any listeners for this user of a new event from an
event source. event source.
@ -265,8 +268,7 @@ class Notifier:
max_room_stream_token: RoomStreamToken, max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [], extra_users: Collection[UserID] = [],
): ):
"""Unwraps event and calls `on_new_room_event_args`. """Unwraps event and calls `on_new_room_event_args`."""
"""
self.on_new_room_event_args( self.on_new_room_event_args(
event_pos=event_pos, event_pos=event_pos,
room_id=event.room_id, room_id=event.room_id,
@ -341,7 +343,10 @@ class Notifier:
if users or rooms: if users or rooms:
self.on_new_event( self.on_new_event(
"room_key", max_room_stream_token, users=users, rooms=rooms, "room_key",
max_room_stream_token,
users=users,
rooms=rooms,
) )
self._on_updated_room_token(max_room_stream_token) self._on_updated_room_token(max_room_stream_token)
@ -418,7 +423,9 @@ class Notifier:
# Notify appservices # Notify appservices
self._notify_app_services_ephemeral( self._notify_app_services_ephemeral(
stream_key, new_token, users, stream_key,
new_token,
users,
) )
def on_new_replication_data(self) -> None: def on_new_replication_data(self) -> None:
@ -651,8 +658,7 @@ class Notifier:
cb() cb()
def notify_remote_server_up(self, server: str): def notify_remote_server_up(self, server: str):
"""Notify any replication that a remote server has come back up """Notify any replication that a remote server has come back up"""
"""
# We call federation_sender directly rather than registering as a # We call federation_sender directly rather than registering as a
# callback as a) we already have a reference to it and b) it introduces # callback as a) we already have a reference to it and b) it introduces
# circular dependencies. # circular dependencies.

View File

@ -144,8 +144,7 @@ class BulkPushRuleEvaluator:
@lru_cache() @lru_cache()
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id """Get the current RulesForRoom object for the given room id"""
"""
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be # before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache) # a race if invalidate_all gets called (which assumes its in the cache)
@ -252,7 +251,9 @@ class BulkPushRuleEvaluator:
# notified for this event. (This will then get handled when we persist # notified for this event. (This will then get handled when we persist
# the event) # the event)
await self.store.add_push_actions_to_staging( await self.store.add_push_actions_to_staging(
event.event_id, actions_by_user, count_as_unread, event.event_id,
actions_by_user,
count_as_unread,
) )

View File

@ -116,8 +116,7 @@ class EmailPusher(Pusher):
self._is_processing = True self._is_processing = True
def _resume_processing(self) -> None: def _resume_processing(self) -> None:
"""Used by tests to resume processing of events after pausing. """Used by tests to resume processing of events after pausing."""
"""
assert self._is_processing assert self._is_processing
self._is_processing = False self._is_processing = False
self._start_processing() self._start_processing()
@ -157,9 +156,11 @@ class EmailPusher(Pusher):
being run. being run.
""" """
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( unprocessed = (
await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering self.user_id, start, self.max_stream_ordering
) )
)
soonest_due_at = None # type: Optional[int] soonest_due_at = None # type: Optional[int]
@ -222,13 +223,15 @@ class EmailPusher(Pusher):
self, last_stream_ordering: int self, last_stream_ordering: int
) -> None: ) -> None:
self.last_stream_ordering = last_stream_ordering self.last_stream_ordering = last_stream_ordering
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( pusher_still_exists = (
await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.app_id,
self.email, self.email,
self.user_id, self.user_id,
last_stream_ordering, last_stream_ordering,
self.clock.time_msec(), self.clock.time_msec(),
) )
)
if not pusher_still_exists: if not pusher_still_exists:
# The pusher has been deleted while we were processing, so # The pusher has been deleted while we were processing, so
# lets just stop and return. # lets just stop and return.
@ -298,7 +301,8 @@ class EmailPusher(Pusher):
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
) )
self.throttle_params[room_id] = ThrottleParams( self.throttle_params[room_id] = ThrottleParams(
self.clock.time_msec(), new_throttle_ms, self.clock.time_msec(),
new_throttle_ms,
) )
assert self.pusher_id is not None assert self.pusher_id is not None
await self.store.set_throttle_params( await self.store.set_throttle_params(

View File

@ -176,9 +176,11 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to Never call this directly: use _process which will only allow this to
run once per pusher. run once per pusher.
""" """
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( unprocessed = (
await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
)
logger.info( logger.info(
"Processing %i unprocessed push actions for %s starting at " "Processing %i unprocessed push actions for %s starting at "
@ -204,13 +206,15 @@ class HttpPusher(Pusher):
http_push_processed_counter.inc() http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( pusher_still_exists = (
await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_id, self.user_id,
self.last_stream_ordering, self.last_stream_ordering,
self.clock.time_msec(), self.clock.time_msec(),
) )
)
if not pusher_still_exists: if not pusher_still_exists:
# The pusher has been deleted while we were processing, so # The pusher has been deleted while we were processing, so
# lets just stop and return. # lets just stop and return.
@ -290,7 +294,8 @@ class HttpPusher(Pusher):
# for sanity, we only remove the pushkey if it # for sanity, we only remove the pushkey if it
# was the one we actually sent... # was the one we actually sent...
logger.warning( logger.warning(
("Ignoring rejected pushkey %s because we didn't send it"), pk, ("Ignoring rejected pushkey %s because we didn't send it"),
pk,
) )
else: else:
logger.info("Pushkey %s was rejected: removing", pk) logger.info("Pushkey %s was rejected: removing", pk)

View File

@ -78,8 +78,7 @@ class PusherPool:
self.pushers = {} # type: Dict[str, Dict[str, Pusher]] self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
def start(self) -> None: def start(self) -> None:
"""Starts the pushers off in a background process. """Starts the pushers off in a background process."""
"""
if not self._should_start_pushers: if not self._should_start_pushers:
logger.info("Not starting pushers because they are disabled in the config") logger.info("Not starting pushers because they are disabled in the config")
return return
@ -297,8 +296,7 @@ class PusherPool:
return pusher return pusher
async def _start_pushers(self) -> None: async def _start_pushers(self) -> None:
"""Start all the pushers """Start all the pushers"""
"""
pushers = await self.store.get_all_pushers() pushers = await self.store.get_all_pushers()
# Stagger starting up the pushers so we don't completely drown the # Stagger starting up the pushers so we don't completely drown the
@ -335,7 +333,8 @@ class PusherPool:
return None return None
except Exception: except Exception:
logger.exception( logger.exception(
"Couldn't start pusher id %i: caught Exception", pusher_config.id, "Couldn't start pusher id %i: caught Exception",
pusher_config.id,
) )
return None return None

View File

@ -273,7 +273,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths( http_server.register_paths(
method, [pattern], self._check_auth_and_handle, self.__class__.__name__, method,
[pattern],
self._check_auth_and_handle,
self.__class__.__name__,
) )
def _check_auth_and_handle(self, request, **kwargs): def _check_auth_and_handle(self, request, **kwargs):

View File

@ -175,7 +175,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request(self, request, user_id, room_id, tag): async def _handle_request(self, request, user_id, room_id, tag):
max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,) max_stream_id = await self.handler.remove_tag_from_room(
user_id,
room_id,
tag,
)
return 200, {"max_stream_id": max_stream_id} return 200, {"max_stream_id": max_stream_id}

View File

@ -160,7 +160,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# hopefully we're now on the master, so this won't recurse! # hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite( event_id, stream_id = await self.member_handler.remote_reject_invite(
invite_event_id, txn_id, requester, event_content, invite_event_id,
txn_id,
requester,
event_content,
) )
return 200, {"event_id": event_id, "stream_id": stream_id} return 200, {"event_id": event_id, "stream_id": stream_id}

View File

@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
class ReplicationRegisterServlet(ReplicationEndpoint): class ReplicationRegisterServlet(ReplicationEndpoint):
"""Register a new user """Register a new user"""
"""
NAME = "register_user" NAME = "register_user"
PATH_ARGS = ("user_id",) PATH_ARGS = ("user_id",)
@ -97,8 +96,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
"""Run any post registration actions """Run any post registration actions"""
"""
NAME = "post_register" NAME = "post_register"
PATH_ARGS = ("user_id",) PATH_ARGS = ("user_id",)

View File

@ -196,8 +196,7 @@ class ErrorCommand(_SimpleCommand):
class PingCommand(_SimpleCommand): class PingCommand(_SimpleCommand):
"""Sent by either side as a keep alive. The data is arbitrary (often timestamp) """Sent by either side as a keep alive. The data is arbitrary (often timestamp)"""
"""
NAME = "PING" NAME = "PING"

View File

@ -60,8 +60,7 @@ class ExternalCache:
return self._redis_connection is not None return self._redis_connection is not None
async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
"""Add the key/value to the named cache, with the expiry time given. """Add the key/value to the named cache, with the expiry time given."""
"""
if self._redis_connection is None: if self._redis_connection is None:
return return
@ -76,13 +75,14 @@ class ExternalCache:
return await make_deferred_yieldable( return await make_deferred_yieldable(
self._redis_connection.set( self._redis_connection.set(
self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, self._get_redis_key(cache_name, key),
encoded_value,
pexpire=expiry_ms,
) )
) )
async def get(self, cache_name: str, key: str) -> Optional[Any]: async def get(self, cache_name: str, key: str) -> Optional[Any]:
"""Look up a key/value in the named cache. """Look up a key/value in the named cache."""
"""
if self._redis_connection is None: if self._redis_connection is None:
return None return None

View File

@ -303,7 +303,9 @@ class ReplicationCommandHandler:
hs, outbound_redis_connection hs, outbound_redis_connection
) )
hs.get_reactor().connectTCP( hs.get_reactor().connectTCP(
hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, hs.config.redis.redis_host,
hs.config.redis.redis_port,
self._factory,
) )
else: else:
client_name = hs.get_instance_name() client_name = hs.get_instance_name()
@ -313,13 +315,11 @@ class ReplicationCommandHandler:
hs.get_reactor().connectTCP(host, port, self._factory) hs.get_reactor().connectTCP(host, port, self._factory)
def get_streams(self) -> Dict[str, Stream]: def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams. """Get a map from stream name to all streams."""
"""
return self._streams return self._streams
def get_streams_to_replicate(self) -> List[Stream]: def get_streams_to_replicate(self) -> List[Stream]:
"""Get a list of streams that this instances replicates. """Get a list of streams that this instances replicates."""
"""
return self._streams_to_replicate return self._streams_to_replicate
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
@ -340,7 +340,10 @@ class ReplicationCommandHandler:
current_token = stream.current_token(self._instance_name) current_token = stream.current_token(self._instance_name)
self.send_command( self.send_command(
PositionCommand( PositionCommand(
stream.NAME, self._instance_name, current_token, current_token, stream.NAME,
self._instance_name,
current_token,
current_token,
) )
) )
@ -592,8 +595,7 @@ class ReplicationCommandHandler:
self.send_command(cmd, ignore_conn=conn) self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: AbstractConnection): def new_connection(self, connection: AbstractConnection):
"""Called when we have a new connection. """Called when we have a new connection."""
"""
self._connections.append(connection) self._connections.append(connection)
# If we are connected to replication as a client (rather than a server) # If we are connected to replication as a client (rather than a server)
@ -620,8 +622,7 @@ class ReplicationCommandHandler:
) )
def lost_connection(self, connection: AbstractConnection): def lost_connection(self, connection: AbstractConnection):
"""Called when a connection is closed/lost. """Called when a connection is closed/lost."""
"""
# we no longer need _streams_by_connection for this connection. # we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None) streams = self._streams_by_connection.pop(connection, None)
if streams: if streams:
@ -678,15 +679,13 @@ class ReplicationCommandHandler:
def send_user_sync( def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
): ):
"""Poke the master that a user has started/stopped syncing. """Poke the master that a user has started/stopped syncing."""
"""
self.send_command( self.send_command(
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
) )
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
"""Poke the master to remove a pusher for a user """Poke the master to remove a pusher for a user"""
"""
cmd = RemovePusherCommand(app_id, push_key, user_id) cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd) self.send_command(cmd)
@ -699,8 +698,7 @@ class ReplicationCommandHandler:
device_id: str, device_id: str,
last_seen: int, last_seen: int,
): ):
"""Tell the master that the user made a request. """Tell the master that the user made a request."""
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd) self.send_command(cmd)

View File

@ -222,8 +222,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.send_error("ping timeout") self.send_error("ping timeout")
def lineReceived(self, line: bytes): def lineReceived(self, line: bytes):
"""Called when we've received a line """Called when we've received a line"""
"""
with PreserveLoggingContext(self._logging_context): with PreserveLoggingContext(self._logging_context):
self._parse_and_dispatch_line(line) self._parse_and_dispatch_line(line)
@ -299,8 +298,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.on_connection_closed() self.on_connection_closed()
def send_error(self, error_string, *args): def send_error(self, error_string, *args):
"""Send an error to remote and close the connection. """Send an error to remote and close the connection."""
"""
self.send_command(ErrorCommand(error_string % args)) self.send_command(ErrorCommand(error_string % args))
self.close() self.close()
@ -341,8 +339,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_sent_command = self.clock.time_msec() self.last_sent_command = self.clock.time_msec()
def _queue_command(self, cmd): def _queue_command(self, cmd):
"""Queue the command until the connection is ready to write to again. """Queue the command until the connection is ready to write to again."""
"""
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd) self.pending_commands.append(cmd)
@ -355,8 +352,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.close() self.close()
def _send_pending_commands(self): def _send_pending_commands(self):
"""Send any queued commandes """Send any queued commandes"""
"""
pending = self.pending_commands pending = self.pending_commands
self.pending_commands = [] self.pending_commands = []
for cmd in pending: for cmd in pending:
@ -380,8 +376,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.PAUSED self.state = ConnectionStates.PAUSED
def resumeProducing(self): def resumeProducing(self):
"""The remote has caught up after we started buffering! """The remote has caught up after we started buffering!"""
"""
logger.info("[%s] Resume producing", self.id()) logger.info("[%s] Resume producing", self.id())
self.state = ConnectionStates.ESTABLISHED self.state = ConnectionStates.ESTABLISHED
self._send_pending_commands() self._send_pending_commands()
@ -440,8 +435,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return "%s-%s" % (self.name, self.conn_id) return "%s-%s" % (self.name, self.conn_id)
def lineLengthExceeded(self, line): def lineLengthExceeded(self, line):
"""Called when we receive a line that is above the maximum line length """Called when we receive a line that is above the maximum line length"""
"""
self.send_error("Line length exceeded") self.send_error("Line length exceeded")
@ -495,21 +489,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_error("Wrong remote") self.send_error("Wrong remote")
def replicate(self): def replicate(self):
"""Send the subscription request to the server """Send the subscription request to the server"""
"""
logger.info("[%s] Subscribing to replication streams", self.id()) logger.info("[%s] Subscribing to replication streams", self.id())
self.send_command(ReplicateCommand()) self.send_command(ReplicateCommand())
class AbstractConnection(abc.ABC): class AbstractConnection(abc.ABC):
"""An interface for replication connections. """An interface for replication connections."""
"""
@abc.abstractmethod @abc.abstractmethod
def send_command(self, cmd: Command): def send_command(self, cmd: Command):
"""Send the command down the connection """Send the command down the connection"""
"""
pass pass

Some files were not shown because too many files have changed in this diff Show More