Use better file consumer

This commit is contained in:
Erik Johnston 2018-01-17 16:56:23 +00:00
parent 4a53f3a3e8
commit 2cf6a7bc20
1 changed files with 10 additions and 5 deletions

View File

@ -15,10 +15,10 @@
from twisted.internet import defer, threads from twisted.internet import defer, threads
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from twisted.protocols.ftp import FileConsumer # This isn't FTP specific
from ._base import Responder from ._base import Responder
from synapse.util.file_consumer import BackgroundFileConsumer
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
import contextlib import contextlib
@ -27,6 +27,7 @@ import logging
import shutil import shutil
import sys import sys
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -168,12 +169,17 @@ class MediaStorage(object):
if os.path.exists(local_path): if os.path.exists(local_path):
defer.returnValue(local_path) defer.returnValue(local_path)
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
for provider in self.storage_providers: for provider in self.storage_providers:
res = yield provider.fetch(path, file_info) res = yield provider.fetch(path, file_info)
if res: if res:
with res: with res:
with open(local_path, "w") as f: consumer = BackgroundFileConsumer(open(local_path, "w"))
res.write_to_consumer(FileConsumer(f)) yield res.write_to_consumer(consumer)
yield consumer.wait()
defer.returnValue(local_path) defer.returnValue(local_path)
raise Exception("file could not be found") raise Exception("file could not be found")
@ -247,9 +253,8 @@ class FileResponder(Responder):
def __init__(self, open_file): def __init__(self, open_file):
self.open_file = open_file self.open_file = open_file
@defer.inlineCallbacks
def write_to_consumer(self, consumer): def write_to_consumer(self, consumer):
yield FileSender().beginFileTransfer(self.open_file, consumer) return FileSender().beginFileTransfer(self.open_file, consumer)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close() self.open_file.close()