Add type hints to `synapse.storage.databases.main.client_ips` (#10972)

This commit is contained in:
Sean Quah 2021-10-12 13:50:34 +01:00 committed by GitHub
parent a18c568516
commit 36224e056a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 121 additions and 45 deletions

1
changelog.d/10972.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.storage.databases.main.client_ips`.

View File

@ -53,6 +53,7 @@ files =
synapse/storage/_base.py, synapse/storage/_base.py,
synapse/storage/background_updates.py, synapse/storage/background_updates.py,
synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/client_ips.py,
synapse/storage/databases/main/events.py, synapse/storage/databases/main/events.py,
synapse/storage/databases/main/keys.py, synapse/storage/databases/main/keys.py,
synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/pusher.py,
@ -108,6 +109,9 @@ disallow_untyped_defs = True
[mypy-synapse.state.*] [mypy-synapse.state.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.client_ips]
disallow_untyped_defs = True
[mypy-synapse.storage.util.*] [mypy-synapse.storage.util.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -14,7 +14,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
)
from synapse.api import errors from synapse.api import errors
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -595,7 +606,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips( def _update_device_from_client_ips(
device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None: ) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})

View File

@ -773,9 +773,9 @@ class ModuleApi:
# Sanitize some of the data. We don't want to return tokens. # Sanitize some of the data. We don't want to return tokens.
return [ return [
UserIpAndAgent( UserIpAndAgent(
ip=str(data["ip"]), ip=data["ip"],
user_agent=str(data["user_agent"]), user_agent=data["user_agent"],
last_seen=int(data["last_seen"]), last_seen=data["last_seen"],
) )
for data in raw_data for data in raw_data
] ]

View File

@ -13,14 +13,26 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause from synapse.storage.database import (
from synapse.types import UserID DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
from synapse.storage.types import Connection
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000 LAST_SEEN_GRANULARITY = 120 * 1000
class DeviceLastConnectionInfo(TypedDict):
"""Metadata for the last connection seen for a user and device combination"""
# These types must match the columns in the `devices` table
user_id: str
device_id: str
ip: Optional[str]
user_agent: Optional[str]
last_seen: Optional[int]
class LastConnectionInfo(TypedDict):
"""Metadata for the last connection seen for an access token and IP combination"""
# These types must match the columns in the `user_ips` table
access_token: str
ip: str
user_agent: str
last_seen: int
class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
"devices_last_seen", self._devices_last_seen_update "devices_last_seen", self._devices_last_seen_update
) )
async def _remove_user_ip_nonunique(self, progress, batch_size): async def _remove_user_ip_nonunique(
def f(conn): self, progress: JsonDict, batch_size: int
) -> int:
def f(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor() txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close() txn.close()
@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
) )
return 1 return 1
async def _analyze_user_ip(self, progress, batch_size): async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
# Background update to analyze user_ips table before we run the # Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed # deduplication background update. The table may not have been analyzed
# for ages due to the table locks. # for ages due to the table locks.
# #
# This will lock out the naive upserts to user_ips while it happens, but # This will lock out the naive upserts to user_ips while it happens, but
# the analyze should be quick (28GB table takes ~10s) # the analyze should be quick (28GB table takes ~10s)
def user_ips_analyze(txn): def user_ips_analyze(txn: LoggingTransaction) -> None:
txn.execute("ANALYZE user_ips") txn.execute("ANALYZE user_ips")
await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return 1 return 1
async def _remove_user_ip_dupes(self, progress, batch_size): async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
# This works function works by scanning the user_ips table in batches # This works function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of # based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they # the table to see if there are any duplicates, if there are then they
# are removed and replaced with a suitable row. # are removed and replaced with a suitable row.
# Fetch the start of the batch # Fetch the start of the batch
begin_last_seen = progress.get("last_seen", 0) begin_last_seen: int = progress.get("last_seen", 0)
def get_last_seen(txn): def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
txn.execute( txn.execute(
""" """
SELECT last_seen FROM user_ips SELECT last_seen FROM user_ips
@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
""", """,
(begin_last_seen, batch_size), (begin_last_seen, batch_size),
) )
row = txn.fetchone() row = cast(Optional[Tuple[int]], txn.fetchone())
if row: if row:
return row[0] return row[0]
else: else:
@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
end_last_seen, end_last_seen,
) )
def remove(txn): def remove(txn: LoggingTransaction) -> None:
# This works by looking at all entries in the given time span, and # This works by looking at all entries in the given time span, and
# then for each (user_id, access_token, ip) tuple in that range # then for each (user_id, access_token, ip) tuple in that range
# checking for any duplicates in the rest of the table (via a join). # checking for any duplicates in the rest of the table (via a join).
@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# Define the search space, which requires handling the last batch in # Define the search space, which requires handling the last batch in
# a different way # a different way
args: Tuple[int, ...]
if last: if last:
clause = "? <= last_seen" clause = "? <= last_seen"
args = (begin_last_seen,) args = (begin_last_seen,)
else: else:
assert end_last_seen is not None
clause = "? <= last_seen AND last_seen < ?" clause = "? <= last_seen AND last_seen < ?"
args = (begin_last_seen, end_last_seen) args = (begin_last_seen, end_last_seen)
@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
), ),
args, args,
) )
res = txn.fetchall() res = cast(
List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
)
# We've got some duplicates # We've got some duplicates
for i in res: for i in res:
@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return batch_size return batch_size
async def _devices_last_seen_update(self, progress, batch_size): async def _devices_last_seen_update(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to insert last seen info into devices table""" """Background update to insert last seen info into devices table"""
last_user_id = progress.get("last_user_id", "") last_user_id: str = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "") last_device_id: str = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn): def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
# This consists of two queries: # This consists of two queries:
# #
# 1. The sub-query searches for the next N devices and joins # 1. The sub-query searches for the next N devices and joins
@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# we'll just end up updating the same device row multiple # we'll just end up updating the same device row multiple
# times, which is fine. # times, which is fine.
where_args: List[Union[str, int]]
where_clause, where_args = make_tuple_comparison_clause( where_clause, where_args = make_tuple_comparison_clause(
[("user_id", last_user_id), ("device_id", last_device_id)], [("user_id", last_user_id), ("device_id", last_device_id)],
) )
@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
} }
txn.execute(sql, where_args + [batch_size]) txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall() rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
if not rows: if not rows:
return 0 return 0
@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.server.user_ips_max_age self.user_ips_max_age = hs.config.server.user_ips_max_age
@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@wrap_as_background_process("prune_old_user_ips") @wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self): async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period.""" """Removes entries in user IPs older than the configured period."""
if self.user_ips_max_age is None: if self.user_ips_max_age is None:
@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
) )
""" """
timestamp = self.clock.time_msec() - self.user_ips_max_age timestamp = self._clock.time_msec() - self.user_ips_max_age
def _prune_old_user_ips_txn(txn): def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
txn.execute(sql, (timestamp,)) txn.execute(sql, (timestamp,))
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
async def get_last_client_ip_by_device( async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str] self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]: ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on. """For each device_id listed, give the user_ip it was last seen on.
The result might be slightly out of date as client IPs are inserted in batches. The result might be slightly out of date as client IPs are inserted in batches.
@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
if device_id is not None: if device_id is not None:
keyvalues["device_id"] = device_id keyvalues["device_id"] = device_id
res = await self.db_pool.simple_select_list( res = cast(
table="devices", List[DeviceLastConnectionInfo],
keyvalues=keyvalues, await self.db_pool.simple_select_list(
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
) )
return {(d["user_id"], d["device_id"]): d for d in res} return {(d["user_id"], d["device_id"]): d for d in res}
class ClientIpStore(ClientIpWorkerStore): class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.client_ip_last_seen = LruCache( # (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
cache_name="client_ip_last_seen", max_size=50000 cache_name="client_ip_last_seen", max_size=50000
) )
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {} self._batch_row_update: Dict[
Tuple[str, str, str], Tuple[str, Optional[str], int]
] = {}
self._client_ip_looper = self._clock.looping_call( self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000 self._update_client_ips_batch, 5 * 1000
@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore):
) )
async def insert_client_ip( async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None self,
): user_id: str,
access_token: str,
ip: str,
user_agent: str,
device_id: Optional[str],
now: Optional[int] = None,
) -> None:
if not now: if not now:
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())
key = (user_id, access_token, ip) key = (user_id, access_token, ip)
@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore):
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
) )
def _update_client_ips_batch_txn(self, txn, to_update): def _update_client_ips_batch_txn(
self,
txn: LoggingTransaction,
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
) -> None:
if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert not self.database_engine.can_native_upsert
): ):
@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore):
async def get_last_client_ip_by_device( async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str] self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]: ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on """For each device_id listed, give the user_ip it was last seen on
Args: Args:
@ -561,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore):
async def get_user_ip_and_agents( async def get_user_ip_and_agents(
self, user: UserID, since_ts: int = 0 self, user: UserID, since_ts: int = 0
) -> List[Dict[str, Union[str, int]]]: ) -> List[LastConnectionInfo]:
""" """
Fetch IP/User Agent connection since a given timestamp. Fetch IP/User Agent connection since a given timestamp.
""" """
user_id = user.to_string() user_id = user.to_string()
results = {} results: Dict[Tuple[str, str], Tuple[str, int]] = {}
for key in self._batch_row_update: for key in self._batch_row_update:
( (
@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore):
if last_seen >= since_ts: if last_seen >= since_ts:
results[(access_token, ip)] = (user_agent, last_seen) results[(access_token, ip)] = (user_agent, last_seen)
def get_recent(txn): def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
txn.execute( txn.execute(
""" """
SELECT access_token, ip, user_agent, last_seen FROM user_ips SELECT access_token, ip, user_agent, last_seen FROM user_ips
@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore):
""", """,
(since_ts, user_id), (since_ts, user_id),
) )
return txn.fetchall() return cast(List[Tuple[str, str, str, int]], txn.fetchall())
rows = await self.db_pool.runInteraction( rows = await self.db_pool.runInteraction(
desc="get_user_ip_and_agents", func=get_recent desc="get_user_ip_and_agents", func=get_recent