Make `synapse._scripts` pass typechecks (#12421)
This commit is contained in:
parent
dd5cc37aa4
commit
0cd182f296
|
@ -0,0 +1 @@
|
|||
Make `synapse._scripts` pass type checks.
|
5
mypy.ini
5
mypy.ini
|
@ -28,11 +28,6 @@ exclude = (?x)
|
|||
|scripts-dev/federation_client.py
|
||||
|scripts-dev/release.py
|
||||
|
||||
|synapse/_scripts/export_signing_key.py
|
||||
|synapse/_scripts/move_remote_media_to_new_store.py
|
||||
|synapse/_scripts/synapse_port_db.py
|
||||
|synapse/_scripts/update_synapse_database.py
|
||||
|
||||
|synapse/storage/databases/__init__.py
|
||||
|synapse/storage/databases/main/cache.py
|
||||
|synapse/storage/databases/main/devices.py
|
||||
|
|
|
@ -17,8 +17,8 @@ import sys
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
import nacl.signing
|
||||
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
|
||||
from signedjson.types import VerifyKey
|
||||
|
||||
|
||||
def exit(status: int = 0, message: Optional[str] = None):
|
||||
|
@ -27,7 +27,7 @@ def exit(status: int = 0, message: Optional[str] = None):
|
|||
sys.exit(status)
|
||||
|
||||
|
||||
def format_plain(public_key: nacl.signing.VerifyKey):
|
||||
def format_plain(public_key: VerifyKey):
|
||||
print(
|
||||
"%s:%s %s"
|
||||
% (
|
||||
|
@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
|
|||
)
|
||||
|
||||
|
||||
def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
|
||||
def format_for_config(public_key: VerifyKey, expiry_ts: int):
|
||||
print(
|
||||
' "%s:%s": { key: "%s", expired_ts: %i }'
|
||||
% (
|
||||
|
|
|
@ -109,10 +109,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument("dest_repo", help="Path to source content repo")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging_config = {
|
||||
"level": logging.DEBUG if args.v else logging.INFO,
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||
}
|
||||
logging.basicConfig(**logging_config)
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.v else logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
main(args.src_repo, args.dest_repo)
|
||||
|
|
|
@ -21,12 +21,13 @@ import logging
|
|||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, Iterable, Optional, Set
|
||||
from types import TracebackType
|
||||
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
|
||||
|
||||
import yaml
|
||||
from matrix_common.versionstring import get_distribution_version_string
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer, reactor as reactor_
|
||||
|
||||
from synapse.config.database import DatabaseConnectionConfig
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
|
@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import (
|
|||
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.util import Clock
|
||||
|
||||
# Cast safety: Twisted does some naughty magic which replaces the
|
||||
# twisted.internet.reactor module with a Reactor instance at runtime.
|
||||
reactor = cast(ISynapseReactor, reactor_)
|
||||
logger = logging.getLogger("synapse_port_db")
|
||||
|
||||
|
||||
|
@ -159,12 +164,14 @@ IGNORED_TABLES = {
|
|||
|
||||
# Error returned by the run function. Used at the top-level part of the script to
|
||||
# handle errors and return codes.
|
||||
end_error = None # type: Optional[str]
|
||||
end_error: Optional[str] = None
|
||||
# The exec_info for the error, if any. If error is defined but not exec_info the script
|
||||
# will show only the error message without the stacktrace, if exec_info is defined but
|
||||
# not the error then the script will show nothing outside of what's printed in the run
|
||||
# function. If both are defined, the script will print both the error and the stacktrace.
|
||||
end_error_exec_info = None
|
||||
end_error_exec_info: Optional[
|
||||
Tuple[Type[BaseException], BaseException, TracebackType]
|
||||
] = None
|
||||
|
||||
|
||||
class Store(
|
||||
|
@ -236,9 +243,12 @@ class MockHomeserver:
|
|||
return "master"
|
||||
|
||||
|
||||
class Porter(object):
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
class Porter:
|
||||
def __init__(self, sqlite_config, progress, batch_size, hs_config):
|
||||
self.sqlite_config = sqlite_config
|
||||
self.progress = progress
|
||||
self.batch_size = batch_size
|
||||
self.hs_config = hs_config
|
||||
|
||||
async def setup_table(self, table):
|
||||
if table in APPEND_ONLY_TABLES:
|
||||
|
@ -323,7 +333,7 @@ class Porter(object):
|
|||
"""
|
||||
txn.execute(sql)
|
||||
|
||||
results = {}
|
||||
results: Dict[str, Set[str]] = {}
|
||||
for table, foreign_table in txn:
|
||||
results.setdefault(table, set()).add(foreign_table)
|
||||
return results
|
||||
|
@ -540,7 +550,8 @@ class Porter(object):
|
|||
db_conn, allow_outdated_version=allow_outdated_version
|
||||
)
|
||||
prepare_database(db_conn, engine, config=self.hs_config)
|
||||
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
|
||||
# Type safety: ignore that we're using Mock homeservers here.
|
||||
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
|
||||
db_conn.commit()
|
||||
|
||||
return store
|
||||
|
@ -724,7 +735,9 @@ class Porter(object):
|
|||
except Exception as e:
|
||||
global end_error_exec_info
|
||||
end_error = str(e)
|
||||
end_error_exec_info = sys.exc_info()
|
||||
# Type safety: we're in an exception handler, so the exc_info() tuple
|
||||
# will not be (None, None, None).
|
||||
end_error_exec_info = sys.exc_info() # type: ignore[assignment]
|
||||
logger.exception("")
|
||||
finally:
|
||||
reactor.stop()
|
||||
|
@ -1023,7 +1036,7 @@ class CursesProgress(Progress):
|
|||
curses.init_pair(1, curses.COLOR_RED, -1)
|
||||
curses.init_pair(2, curses.COLOR_GREEN, -1)
|
||||
|
||||
self.last_update = 0
|
||||
self.last_update = 0.0
|
||||
|
||||
self.finished = False
|
||||
|
||||
|
@ -1082,8 +1095,7 @@ class CursesProgress(Progress):
|
|||
left_margin = 5
|
||||
middle_space = 1
|
||||
|
||||
items = self.tables.items()
|
||||
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
|
||||
items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
|
||||
|
||||
for i, (table, data) in enumerate(items):
|
||||
if i + 2 >= rows:
|
||||
|
@ -1179,15 +1191,11 @@ def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging_config = {
|
||||
"level": logging.DEBUG if args.v else logging.INFO,
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||
}
|
||||
|
||||
if args.curses:
|
||||
logging_config["filename"] = "port-synapse.log"
|
||||
|
||||
logging.basicConfig(**logging_config)
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.v else logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||
filename="port-synapse.log" if args.curses else None,
|
||||
)
|
||||
|
||||
sqlite_config = {
|
||||
"name": "sqlite3",
|
||||
|
@ -1218,6 +1226,7 @@ def main():
|
|||
config.parse_config_dict(hs_config, "", "")
|
||||
|
||||
def start(stdscr=None):
|
||||
progress: Progress
|
||||
if stdscr:
|
||||
progress = CursesProgress(stdscr)
|
||||
else:
|
||||
|
|
|
@ -16,22 +16,27 @@
|
|||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from matrix_common.versionstring import get_distribution_version_string
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer, reactor as reactor_
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage import DataStore
|
||||
from synapse.types import ISynapseReactor
|
||||
|
||||
# Cast safety: Twisted does some naughty magic which replaces the
|
||||
# twisted.internet.reactor module with a Reactor instance at runtime.
|
||||
reactor = cast(ISynapseReactor, reactor_)
|
||||
logger = logging.getLogger("update_database")
|
||||
|
||||
|
||||
class MockHomeserver(HomeServer):
|
||||
DATASTORE_CLASS = DataStore
|
||||
DATASTORE_CLASS = DataStore # type: ignore [assignment]
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super(MockHomeserver, self).__init__(
|
||||
|
@ -85,12 +90,10 @@ def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging_config = {
|
||||
"level": logging.DEBUG if args.v else logging.INFO,
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||
}
|
||||
|
||||
logging.basicConfig(**logging_config)
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.v else logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
# Load, process and sanity-check the config.
|
||||
hs_config = yaml.safe_load(args.database_config)
|
||||
|
|
Loading…
Reference in New Issue