Fix module API's `get_user_ip_and_agents` function when run on workers (#11112)

This commit is contained in:
Sean Quah 2021-10-25 13:01:04 +01:00 committed by GitHub
parent 2b82ec425f
commit 85a09f8b8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 41 deletions

1
changelog.d/11112.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug which caused the module API's `get_user_ip_and_agents` function to always fail on workers. `get_user_ip_and_agents` was introduced in 1.44.0 and did not function correctly on worker processes at the time.

View File

@ -46,6 +46,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse
from synapse.storage import DataStore
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
@ -61,6 +62,7 @@ from synapse.util import Clock
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.app.generic_worker import GenericWorkerSlavedStore
from synapse.server import HomeServer
"""
@ -111,7 +113,9 @@ class ModuleApi:
def __init__(self, hs: "HomeServer", auth_handler):
self._hs = hs
self._store = hs.get_datastore()
# TODO: Fix this type hint once the types for the data stores have been ironed
# out.
self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname

View File

@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res}
async def get_user_ip_and_agents(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp.
The result might be slightly out of date as client IPs are inserted in batches.
Args:
user: The user for which to fetch IP addresses and user agents.
since_ts: The timestamp after which to fetch IP addresses and user agents,
in milliseconds.
Returns:
A list of dictionaries, each containing:
* `access_token`: The access token used.
* `ip`: The IP address used.
* `user_agent`: The last user agent seen for this access token and IP
address combination.
* `last_seen`: The timestamp at which this access token and IP address
combination was last seen, in milliseconds.
Only the latest user agent for each access token and IP address combination
is available.
"""
user_id = user.to_string()
def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
txn.execute(
"""
SELECT access_token, ip, user_agent, last_seen FROM user_ips
WHERE last_seen >= ? AND user_id = ?
ORDER BY last_seen
DESC
""",
(since_ts, user_id),
)
return cast(List[Tuple[str, str, str, int]], txn.fetchall())
rows = await self.db_pool.runInteraction(
desc="get_user_ip_and_agents", func=get_recent
)
return [
{
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
for access_token, ip, user_agent, last_seen in rows
]
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
async def get_user_ip_and_agents(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
"""
Fetch IP/User Agent connection since a given timestamp.
"""
user_id = user.to_string()
results: Dict[Tuple[str, str], Tuple[str, int]] = {}
"""Fetch the IPs and user agents for a user since the given timestamp.
Args:
user: The user for which to fetch IP addresses and user agents.
since_ts: The timestamp after which to fetch IP addresses and user agents,
in milliseconds.
Returns:
A list of dictionaries, each containing:
* `access_token`: The access token used.
* `ip`: The IP address used.
* `user_agent`: The last user agent seen for this access token and IP
address combination.
* `last_seen`: The timestamp at which this access token and IP address
combination was last seen, in milliseconds.
Only the latest user agent for each access token and IP address combination
is available.
"""
results: Dict[Tuple[str, str], LastConnectionInfo] = {
(connection["access_token"], connection["ip"]): connection
for connection in await super().get_user_ip_and_agents(user, since_ts)
}
# Overlay data that is pending insertion on top of the results from the
# database.
user_id = user.to_string()
for key in self._batch_row_update:
(
uid,
access_token,
ip,
) = key
uid, access_token, ip = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
if last_seen >= since_ts:
results[(access_token, ip)] = (user_agent, last_seen)
results[(access_token, ip)] = {
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
txn.execute(
"""
SELECT access_token, ip, user_agent, last_seen FROM user_ips
WHERE last_seen >= ? AND user_id = ?
ORDER BY last_seen
DESC
""",
(since_ts, user_id),
)
return cast(List[Tuple[str, str, str, int]], txn.fetchall())
rows = await self.db_pool.runInteraction(
desc="get_user_ip_and_agents", func=get_recent
)
results.update(
((access_token, ip), (user_agent, last_seen))
for access_token, ip, user_agent, last_seen in rows
)
return [
{
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in results.items()
]
return list(results.values())