Move media admin store functions to worker store

This commit is contained in:
Erik Johnston 2020-01-08 13:20:43 +00:00
parent 235d977e1f
commit 3cf7d6d5b6
1 changed files with 120 additions and 120 deletions

View File

@ -366,6 +366,126 @@ class RoomWorkerStore(SQLBaseStore):
defer.returnValue(row)
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
# Convert the IDs to MXC URIs
for media_id in local_mxcs:
local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
for hostname, media_id in remote_mxcs:
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.db.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
),
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
txn (cursor)
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
next_token = self.get_current_events_token() + 1
local_media_mxcs = []
remote_media_mxcs = []
while next_token:
sql = """
SELECT stream_ordering, json FROM events
JOIN event_json USING (room_id, event_id)
WHERE room_id = ?
AND stream_ordering < ?
AND contains_url = ? AND outlier = ?
ORDER BY stream_ordering DESC
LIMIT ?
"""
txn.execute(sql, (room_id, next_token, True, False, 100))
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
event_json = json.loads(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hs.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@ -810,126 +930,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
(room_id,),
)
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
# Convert the IDs to MXC URIs
for media_id in local_mxcs:
local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
for hostname, media_id in remote_mxcs:
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.db.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
),
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
txn (cursor)
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
next_token = self.get_current_events_token() + 1
local_media_mxcs = []
remote_media_mxcs = []
while next_token:
sql = """
SELECT stream_ordering, json FROM events
JOIN event_json USING (room_id, event_id)
WHERE room_id = ?
AND stream_ordering < ?
AND contains_url = ? AND outlier = ?
ORDER BY stream_ordering DESC
LIMIT ?
"""
txn.execute(sql, (room_id, next_token, True, False, 100))
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
event_json = json.loads(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hs.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
@defer.inlineCallbacks
def get_rooms_for_retention_period_in_range(
self, min_ms, max_ms, include_null=False