Allow only one download for a given image at a time, so that we don't end up downloading the same image twice if two clients request a remote image at the same time

This commit is contained in:
Mark Haines 2014-12-11 16:48:11 +00:00
parent d80d505b1f
commit 03d9024cbc
3 changed files with 29 additions and 20 deletions

View File

@ -45,6 +45,7 @@ class BaseMediaResource(Resource):
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths self.filepaths = filepaths
self.downloads = {}
@staticmethod @staticmethod
def catch_errors(request_handler): def catch_errors(request_handler):
@ -128,6 +129,28 @@ class BaseMediaResource(Resource):
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
def _get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return download
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
defer.returnValue(media_info)
@defer.inlineCallbacks @defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id): def _download_remote_file(self, server_name, media_id):
file_id = random_string(24) file_id = random_string(24)
@ -231,7 +254,7 @@ class BaseMediaResource(Resource):
if m_width * m_height >= self.max_image_pixels: if m_width * m_height >= self.max_image_pixels:
logger.info( logger.info(
"Image too large to thumbnail %r x %r > %r" "Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels m_width, m_height, self.max_image_pixels
) )
return return
@ -294,7 +317,7 @@ class BaseMediaResource(Resource):
if m_width * m_height >= self.max_image_pixels: if m_width * m_height >= self.max_image_pixels:
logger.info( logger.info(
"Image too large to thumbnail %r x %r > %r" "Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels m_width, m_height, self.max_image_pixels
) )
return return

View File

@ -56,14 +56,7 @@ class DownloadResource(BaseMediaResource):
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id): def _respond_remote_file(self, request, server_name, media_id):
media_info = yield self.store.get_cached_remote_media( media_info = yield self._get_remote_media(server_name, media_id)
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
media_type = media_info["media_type"] media_type = media_info["media_type"]
filesystem_id = media_info["filesystem_id"] filesystem_id = media_info["filesystem_id"]

View File

@ -83,16 +83,9 @@ class ThumbnailResource(BaseMediaResource):
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width, def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type): height, method, m_type):
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
# TODO: Don't download the whole remote file # TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead. # We should proxy the thumbnail from the remote server instead.
media_info = yield self._download_remote_file( media_info = yield self._get_remote_media(server_name, media_id)
server_name, media_id
)
thumbnail_infos = yield self.store.get_remote_media_thumbnails( thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id, server_name, media_id,