Make `synapse._scripts` pass typechecks (#12421)

This commit is contained in:
David Robertson 2022-04-08 15:00:12 +01:00 committed by GitHub
parent dd5cc37aa4
commit 0cd182f296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 43 deletions

1
changelog.d/12421.misc Normal file
View File

@ -0,0 +1 @@
Make `synapse._scripts` pass type checks.

View File

@ -28,11 +28,6 @@ exclude = (?x)
|scripts-dev/federation_client.py
|scripts-dev/release.py
|synapse/_scripts/export_signing_key.py
|synapse/_scripts/move_remote_media_to_new_store.py
|synapse/_scripts/synapse_port_db.py
|synapse/_scripts/update_synapse_database.py
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py

View File

@ -17,8 +17,8 @@ import sys
import time
from typing import Optional
import nacl.signing
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
from signedjson.types import VerifyKey
def exit(status: int = 0, message: Optional[str] = None):
@ -27,7 +27,7 @@ def exit(status: int = 0, message: Optional[str] = None):
sys.exit(status)
def format_plain(public_key: nacl.signing.VerifyKey):
def format_plain(public_key: VerifyKey):
print(
"%s:%s %s"
% (
@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
)
def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
def format_for_config(public_key: VerifyKey, expiry_ts: int):
print(
' "%s:%s": { key: "%s", expired_ts: %i }'
% (

View File

@ -109,10 +109,9 @@ if __name__ == "__main__":
parser.add_argument("dest_repo", help="Path to source content repo")
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
)
main(args.src_repo, args.dest_repo)

View File

@ -21,12 +21,13 @@ import logging
import sys
import time
import traceback
from typing import Dict, Iterable, Optional, Set
from types import TracebackType
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
import yaml
from matrix_common.versionstring import get_distribution_version_string
from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as reactor_
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import (
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor
from synapse.util import Clock
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("synapse_port_db")
@ -159,12 +164,14 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
end_error = None # type: Optional[str]
end_error: Optional[str] = None
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace.
end_error_exec_info = None
end_error_exec_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType]
] = None
class Store(
@ -236,9 +243,12 @@ class MockHomeserver:
return "master"
class Porter(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class Porter:
def __init__(self, sqlite_config, progress, batch_size, hs_config):
self.sqlite_config = sqlite_config
self.progress = progress
self.batch_size = batch_size
self.hs_config = hs_config
async def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
@ -323,7 +333,7 @@ class Porter(object):
"""
txn.execute(sql)
results = {}
results: Dict[str, Set[str]] = {}
for table, foreign_table in txn:
results.setdefault(table, set()).add(foreign_table)
return results
@ -540,7 +550,8 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
# Type safety: ignore that we're using Mock homeservers here.
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
db_conn.commit()
return store
@ -724,7 +735,9 @@ class Porter(object):
except Exception as e:
global end_error_exec_info
end_error = str(e)
end_error_exec_info = sys.exc_info()
# Type safety: we're in an exception handler, so the exc_info() tuple
# will not be (None, None, None).
end_error_exec_info = sys.exc_info() # type: ignore[assignment]
logger.exception("")
finally:
reactor.stop()
@ -1023,7 +1036,7 @@ class CursesProgress(Progress):
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)
self.last_update = 0
self.last_update = 0.0
self.finished = False
@ -1082,8 +1095,7 @@ class CursesProgress(Progress):
left_margin = 5
middle_space = 1
items = self.tables.items()
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
@ -1179,15 +1191,11 @@ def main():
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
if args.curses:
logging_config["filename"] = "port-synapse.log"
logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
filename="port-synapse.log" if args.curses else None,
)
sqlite_config = {
"name": "sqlite3",
@ -1218,6 +1226,7 @@ def main():
config.parse_config_dict(hs_config, "", "")
def start(stdscr=None):
progress: Progress
if stdscr:
progress = CursesProgress(stdscr)
else:

View File

@ -16,22 +16,27 @@
import argparse
import logging
import sys
from typing import cast
import yaml
from matrix_common.versionstring import get_distribution_version_string
from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as reactor_
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.types import ISynapseReactor
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore
DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config, **kwargs):
super(MockHomeserver, self).__init__(
@ -85,12 +90,10 @@ def main():
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
}
logging.basicConfig(**logging_config)
logging.basicConfig(
level=logging.DEBUG if args.v else logging.INFO,
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
)
# Load, process and sanity-check the config.
hs_config = yaml.safe_load(args.database_config)