more work saved

This commit is contained in:
Trenton H 2024-04-09 13:53:59 -07:00
parent 5f18e4c9bc
commit f6989e99df
6 changed files with 241 additions and 179 deletions

View File

@ -17,6 +17,7 @@ from filelock import FileLock
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from documents.classifier import load_classifier from documents.classifier import load_classifier
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentMetadataOverrides
from documents.file_handling import create_source_path_directory from documents.file_handling import create_source_path_directory
from documents.file_handling import generate_unique_filename from documents.file_handling import generate_unique_filename
@ -42,6 +43,7 @@ from documents.plugins.base import AlwaysRunPluginMixin
from documents.plugins.base import ConsumeTaskPlugin from documents.plugins.base import ConsumeTaskPlugin
from documents.plugins.base import NoCleanupPluginMixin from documents.plugins.base import NoCleanupPluginMixin
from documents.plugins.base import NoSetupPluginMixin from documents.plugins.base import NoSetupPluginMixin
from documents.plugins.helpers import ProgressManager
from documents.signals import document_consumption_finished from documents.signals import document_consumption_finished
from documents.signals import document_consumption_started from documents.signals import document_consumption_started
from documents.utils import copy_basic_file_stats from documents.utils import copy_basic_file_stats
@ -254,6 +256,32 @@ class ConsumerFilePhase(str, Enum):
class ConsumerPlugin(AlwaysRunPluginMixin, ConsumeTaskPlugin, LoggingMixin): class ConsumerPlugin(AlwaysRunPluginMixin, ConsumeTaskPlugin, LoggingMixin):
logging_name = "paperless.consumer" logging_name = "paperless.consumer"
def __init__(
self,
input_doc: ConsumableDocument,
metadata: DocumentMetadataOverrides,
status_mgr: ProgressManager,
base_tmp_dir: Path,
task_id: str,
) -> None:
super().__init__(input_doc, metadata, status_mgr, base_tmp_dir, task_id)
self.original_path = self.input_doc.original_file
self.filename = self.metadata.filename or self.input_doc.original_file.name
self.override_title = self.metadata.title
self.override_correspondent_id = self.metadata.correspondent_id
self.override_document_type_id = self.metadata.document_type_id
self.override_tag_ids = self.metadata.tag_ids
self.override_storage_path_id = self.metadata.storage_path_id
self.override_created = self.metadata.created
self.override_asn = self.metadata.asn
self.override_owner_id = self.metadata.owner_id
self.override_view_users = self.metadata.view_users
self.override_view_groups = self.metadata.view_groups
self.override_change_users = self.metadata.change_users
self.override_change_groups = self.metadata.change_groups
self.override_custom_field_ids = self.metadata.custom_field_ids
def setup(self) -> None: def setup(self) -> None:
pass pass
@ -474,22 +502,6 @@ class ConsumerPlugin(AlwaysRunPluginMixin, ConsumeTaskPlugin, LoggingMixin):
Return the document object if it was successfully created. Return the document object if it was successfully created.
""" """
self.original_path = self.input_doc.original_file
self.filename = self.metadata.filename or self.input_doc.original_file.name
self.override_title = self.metadata.title
self.override_correspondent_id = self.metadata.correspondent_id
self.override_document_type_id = self.metadata.document_type_id
self.override_tag_ids = self.metadata.tag_ids
self.override_storage_path_id = self.metadata.storage_path_id
self.override_created = self.metadata.created
self.override_asn = self.metadata.asn
self.override_owner_id = self.metadata.owner_id
self.override_view_users = self.metadata.view_users
self.override_view_groups = self.metadata.view_groups
self.override_change_users = self.metadata.change_users
self.override_change_groups = self.metadata.change_groups
self.override_custom_field_ids = self.metadata.custom_field_ids
self._send_progress( self._send_progress(
0, 0,
100, 100,

View File

@ -162,6 +162,8 @@ def consume_file(
finally: finally:
plugin.cleanup() plugin.cleanup()
return msg
@shared_task @shared_task
def sanity_check(): def sanity_check():

View File

@ -4,11 +4,9 @@ import re
import shutil import shutil
import stat import stat
import tempfile import tempfile
import uuid
import zoneinfo import zoneinfo
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from unittest import TestCase as StdLibTestCase
from unittest import mock from unittest import mock
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -23,10 +21,7 @@ from guardian.core import ObjectPermissionChecker
from documents.consumer import ConsumerError from documents.consumer import ConsumerError
from documents.consumer import ConsumerFilePhase from documents.consumer import ConsumerFilePhase
from documents.consumer import ConsumerPlugin
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentMetadataOverrides
from documents.data_models import DocumentSource
from documents.models import Correspondent from documents.models import Correspondent
from documents.models import CustomField from documents.models import CustomField
from documents.models import Document from documents.models import Document
@ -38,11 +33,11 @@ from documents.parsers import DocumentParser
from documents.parsers import ParseError from documents.parsers import ParseError
from documents.tasks import sanity_check from documents.tasks import sanity_check
from documents.tests.utils import DirectoriesMixin from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import DummyProgressManager
from documents.tests.utils import FileSystemAssertsMixin from documents.tests.utils import FileSystemAssertsMixin
from documents.tests.utils import GetConsumerMixin
class TestAttributes(TestCase): class TestAttributes(StdLibTestCase):
TAGS = ("tag1", "tag2", "tag3") TAGS = ("tag1", "tag2", "tag3")
def _test_guess_attributes_from_name(self, filename, sender, title, tags): def _test_guess_attributes_from_name(self, filename, sender, title, tags):
@ -253,7 +248,12 @@ def fake_magic_from_file(file, mime=False):
@mock.patch("documents.consumer.magic.from_file", fake_magic_from_file) @mock.patch("documents.consumer.magic.from_file", fake_magic_from_file)
class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase): class TestConsumer(
DirectoriesMixin,
FileSystemAssertsMixin,
GetConsumerMixin,
TestCase,
):
def _assert_first_last_send_progress( def _assert_first_last_send_progress(
self, self,
first_status=ConsumerFilePhase.STARTED, first_status=ConsumerFilePhase.STARTED,
@ -293,26 +293,6 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
): ):
return FaultyGenericExceptionParser(logging_group, self.dirs.scratch_dir) return FaultyGenericExceptionParser(logging_group, self.dirs.scratch_dir)
@contextmanager
def get_consumer(
self,
filepath: Path,
overrides: DocumentMetadataOverrides | None = None,
source: DocumentSource = DocumentSource.ConsumeFolder,
) -> Generator[ConsumerPlugin, None, None]:
# Store this for verification
self.status = DummyProgressManager(filepath.name, None)
reader = ConsumerPlugin(
ConsumableDocument(source, original_file=filepath),
overrides or DocumentMetadataOverrides(),
self.status, # type: ignore
self.dirs.scratch_dir,
"task-id",
)
reader.setup()
yield reader
reader.cleanup()
def setUp(self): def setUp(self):
super().setUp() super().setUp()
@ -723,7 +703,7 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
self._assert_first_last_send_progress(last_status="FAILED") self._assert_first_last_send_progress(last_status="FAILED")
@mock.patch("documents.consumer.Consumer._write") @mock.patch("documents.consumer.ConsumerPlugin._write")
def testPostSaveError(self, m): def testPostSaveError(self, m):
filename = self.get_test_file() filename = self.get_test_file()
m.side_effect = OSError("NO.") m.side_effect = OSError("NO.")
@ -745,9 +725,14 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
@override_settings(FILENAME_FORMAT="{correspondent}/{title}") @override_settings(FILENAME_FORMAT="{correspondent}/{title}")
def testFilenameHandling(self): def testFilenameHandling(self):
filename = self.get_test_file()
document = self.consumer.try_consume_file(filename, override_title="new docs") with self.get_consumer(
self.get_test_file(),
DocumentMetadataOverrides(title="new docs"),
) as consumer:
consumer.run()
document = Document.objects.first()
self.assertEqual(document.title, "new docs") self.assertEqual(document.title, "new docs")
self.assertEqual(document.filename, "none/new docs.pdf") self.assertEqual(document.filename, "none/new docs.pdf")
@ -767,11 +752,15 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
m.side_effect = lambda f, archive_filename=False: get_filename() m.side_effect = lambda f, archive_filename=False: get_filename()
filename = self.get_test_file()
Tag.objects.create(name="test", is_inbox_tag=True) Tag.objects.create(name="test", is_inbox_tag=True)
document = self.consumer.try_consume_file(filename, override_title="new docs") with self.get_consumer(
self.get_test_file(),
DocumentMetadataOverrides(title="new docs"),
) as consumer:
consumer.run()
document = Document.objects.first()
self.assertEqual(document.title, "new docs") self.assertEqual(document.title, "new docs")
self.assertIsNotNone(document.title) self.assertIsNotNone(document.title)
@ -798,7 +787,10 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
m.return_value.predict_document_type.return_value = dtype.pk m.return_value.predict_document_type.return_value = dtype.pk
m.return_value.predict_tags.return_value = [t1.pk] m.return_value.predict_tags.return_value = [t1.pk]
document = self.consumer.try_consume_file(self.get_test_file()) with self.get_consumer(self.get_test_file()) as consumer:
consumer.run()
document = Document.objects.first()
self.assertEqual(document.correspondent, correspondent) self.assertEqual(document.correspondent, correspondent)
self.assertEqual(document.document_type, dtype) self.assertEqual(document.document_type, dtype)
@ -811,18 +803,24 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
def test_delete_duplicate(self): def test_delete_duplicate(self):
dst = self.get_test_file() dst = self.get_test_file()
self.assertIsFile(dst) self.assertIsFile(dst)
doc = self.consumer.try_consume_file(dst)
with self.get_consumer(dst) as consumer:
consumer.run()
document = Document.objects.first()
self._assert_first_last_send_progress() self._assert_first_last_send_progress()
self.assertIsNotFile(dst) self.assertIsNotFile(dst)
self.assertIsNotNone(doc) self.assertIsNotNone(document)
self._send_progress.reset_mock()
dst = self.get_test_file() dst = self.get_test_file()
self.assertIsFile(dst) self.assertIsFile(dst)
self.assertRaises(ConsumerError, self.consumer.try_consume_file, dst)
with self.get_consumer(dst) as consumer:
with self.assertRaises(ConsumerError):
consumer.run()
self.assertIsNotFile(dst) self.assertIsNotFile(dst)
self._assert_first_last_send_progress(last_status="FAILED") self._assert_first_last_send_progress(last_status="FAILED")
@ -830,31 +828,40 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
def test_no_delete_duplicate(self): def test_no_delete_duplicate(self):
dst = self.get_test_file() dst = self.get_test_file()
self.assertIsFile(dst) self.assertIsFile(dst)
doc = self.consumer.try_consume_file(dst)
with self.get_consumer(dst) as consumer:
consumer.run()
document = Document.objects.first()
self._assert_first_last_send_progress()
self.assertIsNotFile(dst) self.assertIsNotFile(dst)
self.assertIsNotNone(doc) self.assertIsNotNone(document)
dst = self.get_test_file() dst = self.get_test_file()
self.assertIsFile(dst) self.assertIsFile(dst)
self.assertRaises(ConsumerError, self.consumer.try_consume_file, dst)
self.assertIsFile(dst)
with self.get_consumer(dst) as consumer:
with self.assertRaises(ConsumerError):
consumer.run()
self.assertIsFile(dst)
self._assert_first_last_send_progress(last_status="FAILED") self._assert_first_last_send_progress(last_status="FAILED")
@override_settings(FILENAME_FORMAT="{title}") @override_settings(FILENAME_FORMAT="{title}")
@mock.patch("documents.parsers.document_consumer_declaration.send") @mock.patch("documents.parsers.document_consumer_declaration.send")
def test_similar_filenames(self, m): def test_similar_filenames(self, m):
shutil.copy( shutil.copy(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), os.path.join(Path(__file__).parent, "samples", "simple.pdf"),
os.path.join(settings.CONSUMPTION_DIR, "simple.pdf"), os.path.join(settings.CONSUMPTION_DIR, "simple.pdf"),
) )
shutil.copy( shutil.copy(
os.path.join(os.path.dirname(__file__), "samples", "simple.png"), os.path.join(Path(__file__).parent, "samples", "simple.png"),
os.path.join(settings.CONSUMPTION_DIR, "simple.png"), os.path.join(settings.CONSUMPTION_DIR, "simple.png"),
) )
shutil.copy( shutil.copy(
os.path.join(os.path.dirname(__file__), "samples", "simple-noalpha.png"), os.path.join(Path(__file__).parent, "samples", "simple-noalpha.png"),
os.path.join(settings.CONSUMPTION_DIR, "simple.png.pdf"), os.path.join(settings.CONSUMPTION_DIR, "simple.png.pdf"),
) )
m.return_value = [ m.return_value = [
@ -867,15 +874,21 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
}, },
), ),
] ]
doc1 = self.consumer.try_consume_file(
os.path.join(settings.CONSUMPTION_DIR, "simple.png"), with self.get_consumer(settings.CONSUMPTION_DIR / "simple.png") as consumer:
) consumer.run()
doc2 = self.consumer.try_consume_file(
os.path.join(settings.CONSUMPTION_DIR, "simple.pdf"), doc1 = Document.objects.last()
)
doc3 = self.consumer.try_consume_file( with self.get_consumer(settings.CONSUMPTION_DIR / "simple.pdf") as consumer:
os.path.join(settings.CONSUMPTION_DIR, "simple.png.pdf"), consumer.run()
)
doc2 = Document.objects.last()
with self.get_consumer(settings.CONSUMPTION_DIR / "simple.png.pdf") as consumer:
consumer.run()
doc3 = Document.objects.last()
self.assertEqual(doc1.filename, "simple.png") self.assertEqual(doc1.filename, "simple.png")
self.assertEqual(doc1.archive_filename, "simple.pdf") self.assertEqual(doc1.archive_filename, "simple.pdf")
@ -888,7 +901,7 @@ class TestConsumer(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
@mock.patch("documents.consumer.magic.from_file", fake_magic_from_file) @mock.patch("documents.consumer.magic.from_file", fake_magic_from_file)
class TestConsumerCreatedDate(DirectoriesMixin, TestCase): class TestConsumerCreatedDate(DirectoriesMixin, GetConsumerMixin, TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
@ -900,17 +913,20 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
THEN: THEN:
- Should parse the date from the file content - Should parse the date from the file content
""" """
src = os.path.join( src = (
os.path.dirname(__file__), Path(__file__).parent
"samples", / "samples"
"documents", / "documents"
"originals", / "originals"
"0000005.pdf", / "0000005.pdf"
) )
dst = os.path.join(self.dirs.scratch_dir, "sample.pdf") dst = self.dirs.scratch_dir / "sample.pdf"
shutil.copy(src, dst) shutil.copy(src, dst)
document = self.consumer.try_consume_file(dst) with self.get_consumer(dst) as consumer:
consumer.run()
document = Document.objects.first()
self.assertEqual( self.assertEqual(
document.created, document.created,
@ -927,17 +943,20 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
THEN: THEN:
- Should parse the date from the filename - Should parse the date from the filename
""" """
src = os.path.join( src = (
os.path.dirname(__file__), Path(__file__).parent
"samples", / "samples"
"documents", / "documents"
"originals", / "originals"
"0000005.pdf", / "0000005.pdf"
) )
dst = os.path.join(self.dirs.scratch_dir, "Scan - 2022-02-01.pdf") dst = self.dirs.scratch_dir / "Scan - 2022-02-01.pdf"
shutil.copy(src, dst) shutil.copy(src, dst)
document = self.consumer.try_consume_file(dst) with self.get_consumer(dst) as consumer:
consumer.run()
document = Document.objects.first()
self.assertEqual( self.assertEqual(
document.created, document.created,
@ -954,17 +973,20 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
THEN: THEN:
- Should parse the date from the content - Should parse the date from the content
""" """
src = os.path.join( src = (
os.path.dirname(__file__), Path(__file__).parent
"samples", / "samples"
"documents", / "documents"
"originals", / "originals"
"0000005.pdf", / "0000005.pdf"
) )
dst = os.path.join(self.dirs.scratch_dir, "Scan - 2022-02-01.pdf") dst = self.dirs.scratch_dir / "Scan - 2022-02-01.pdf"
shutil.copy(src, dst) shutil.copy(src, dst)
document = self.consumer.try_consume_file(dst) with self.get_consumer(dst) as consumer:
consumer.run()
document = Document.objects.first()
self.assertEqual( self.assertEqual(
document.created, document.created,
@ -983,17 +1005,20 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
THEN: THEN:
- Should parse the date from the filename - Should parse the date from the filename
""" """
src = os.path.join( src = (
os.path.dirname(__file__), Path(__file__).parent
"samples", / "samples"
"documents", / "documents"
"originals", / "originals"
"0000006.pdf", / "0000006.pdf"
) )
dst = os.path.join(self.dirs.scratch_dir, "0000006.pdf") dst = self.dirs.scratch_dir / "0000006.pdf"
shutil.copy(src, dst) shutil.copy(src, dst)
document = self.consumer.try_consume_file(dst) with self.get_consumer(dst) as consumer:
consumer.run()
document = Document.objects.first()
self.assertEqual( self.assertEqual(
document.created, document.created,
@ -1001,41 +1026,40 @@ class TestConsumerCreatedDate(DirectoriesMixin, TestCase):
) )
class PreConsumeTestCase(TestCase): class PreConsumeTestCase(DirectoriesMixin, GetConsumerMixin, TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# this prevents websocket message reports during testing. super().setUp()
patcher = mock.patch("documents.consumer.Consumer._send_progress") src = (
self._send_progress = patcher.start() Path(__file__).parent
self.addCleanup(patcher.stop) / "samples"
/ "documents"
return super().setUp() / "originals"
/ "0000005.pdf"
)
self.test_file = self.dirs.scratch_dir / "sample.pdf"
shutil.copy(src, self.test_file)
@mock.patch("documents.consumer.run_subprocess") @mock.patch("documents.consumer.run_subprocess")
@override_settings(PRE_CONSUME_SCRIPT=None) @override_settings(PRE_CONSUME_SCRIPT=None)
def test_no_pre_consume_script(self, m): def test_no_pre_consume_script(self, m):
c = Consumer() with self.get_consumer(self.test_file) as c:
c.working_copy = "path-to-file" c.run()
c.run_pre_consume_script()
m.assert_not_called() m.assert_not_called()
@mock.patch("documents.consumer.run_subprocess") @mock.patch("documents.consumer.run_subprocess")
@mock.patch("documents.consumer.Consumer._send_progress")
@override_settings(PRE_CONSUME_SCRIPT="does-not-exist") @override_settings(PRE_CONSUME_SCRIPT="does-not-exist")
def test_pre_consume_script_not_found(self, m, m2): def test_pre_consume_script_not_found(self, m):
c = Consumer() with self.get_consumer(self.test_file) as c:
c.filename = "somefile.pdf"
c.working_copy = "path-to-file" self.assertRaises(ConsumerError, c.run)
self.assertRaises(ConsumerError, c.run_pre_consume_script) m.assert_not_called()
@mock.patch("documents.consumer.run_subprocess") @mock.patch("documents.consumer.run_subprocess")
def test_pre_consume_script(self, m): def test_pre_consume_script(self, m):
with tempfile.NamedTemporaryFile() as script: with tempfile.NamedTemporaryFile() as script:
with override_settings(PRE_CONSUME_SCRIPT=script.name): with override_settings(PRE_CONSUME_SCRIPT=script.name):
c = Consumer() with self.get_consumer(self.test_file) as c:
c.original_path = "path-to-file" c.run()
c.working_copy = "/tmp/somewhere/path-to-file"
c.task_id = str(uuid.uuid4())
c.run_pre_consume_script()
m.assert_called_once() m.assert_called_once()
@ -1045,11 +1069,11 @@ class PreConsumeTestCase(TestCase):
environment = args[1] environment = args[1]
self.assertEqual(command[0], script.name) self.assertEqual(command[0], script.name)
self.assertEqual(command[1], "path-to-file") self.assertEqual(command[1], str(self.test_file))
subset = { subset = {
"DOCUMENT_SOURCE_PATH": c.original_path, "DOCUMENT_SOURCE_PATH": str(c.original_path),
"DOCUMENT_WORKING_PATH": c.working_copy, "DOCUMENT_WORKING_PATH": str(c.working_copy),
"TASK_ID": c.task_id, "TASK_ID": c.task_id,
} }
self.assertDictEqual(environment, {**environment, **subset}) self.assertDictEqual(environment, {**environment, **subset})
@ -1076,10 +1100,8 @@ class PreConsumeTestCase(TestCase):
with override_settings(PRE_CONSUME_SCRIPT=script.name): with override_settings(PRE_CONSUME_SCRIPT=script.name):
with self.assertLogs("paperless.consumer", level="INFO") as cm: with self.assertLogs("paperless.consumer", level="INFO") as cm:
c = Consumer() with self.get_consumer(self.test_file) as c:
c.working_copy = "path-to-file" c.run()
c.run_pre_consume_script()
self.assertIn( self.assertIn(
"INFO:paperless.consumer:This message goes to stdout", "INFO:paperless.consumer:This message goes to stdout",
cm.output, cm.output,
@ -1109,22 +1131,25 @@ class PreConsumeTestCase(TestCase):
os.chmod(script.name, st.st_mode | stat.S_IEXEC) os.chmod(script.name, st.st_mode | stat.S_IEXEC)
with override_settings(PRE_CONSUME_SCRIPT=script.name): with override_settings(PRE_CONSUME_SCRIPT=script.name):
c = Consumer() with self.get_consumer(self.test_file) as c:
c.working_copy = "path-to-file"
self.assertRaises( self.assertRaises(
ConsumerError, ConsumerError,
c.run_pre_consume_script, c.run,
) )
class PostConsumeTestCase(TestCase): class PostConsumeTestCase(DirectoriesMixin, GetConsumerMixin, TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# this prevents websocket message reports during testing. super().setUp()
patcher = mock.patch("documents.consumer.Consumer._send_progress") src = (
self._send_progress = patcher.start() Path(__file__).parent
self.addCleanup(patcher.stop) / "samples"
/ "documents"
return super().setUp() / "originals"
/ "0000005.pdf"
)
self.test_file = self.dirs.scratch_dir / "sample.pdf"
shutil.copy(src, self.test_file)
@mock.patch("documents.consumer.run_subprocess") @mock.patch("documents.consumer.run_subprocess")
@override_settings(POST_CONSUME_SCRIPT=None) @override_settings(POST_CONSUME_SCRIPT=None)
@ -1135,21 +1160,17 @@ class PostConsumeTestCase(TestCase):
doc.tags.add(tag1) doc.tags.add(tag1)
doc.tags.add(tag2) doc.tags.add(tag2)
Consumer().run_post_consume_script(doc) with self.get_consumer(self.test_file) as consumer:
consumer.run_post_consume_script(doc)
m.assert_not_called() m.assert_not_called()
@override_settings(POST_CONSUME_SCRIPT="does-not-exist") @override_settings(POST_CONSUME_SCRIPT="does-not-exist")
@mock.patch("documents.consumer.Consumer._send_progress") def test_post_consume_script_not_found(self):
def test_post_consume_script_not_found(self, m):
doc = Document.objects.create(title="Test", mime_type="application/pdf") doc = Document.objects.create(title="Test", mime_type="application/pdf")
c = Consumer()
c.filename = "somefile.pdf" with self.get_consumer(self.test_file) as consumer:
self.assertRaises( with self.assertRaises(ConsumerError):
ConsumerError, consumer.run_post_consume_script(doc)
c.run_post_consume_script,
doc,
)
@mock.patch("documents.consumer.run_subprocess") @mock.patch("documents.consumer.run_subprocess")
def test_post_consume_script_simple(self, m): def test_post_consume_script_simple(self, m):
@ -1157,7 +1178,8 @@ class PostConsumeTestCase(TestCase):
with override_settings(POST_CONSUME_SCRIPT=script.name): with override_settings(POST_CONSUME_SCRIPT=script.name):
doc = Document.objects.create(title="Test", mime_type="application/pdf") doc = Document.objects.create(title="Test", mime_type="application/pdf")
Consumer().run_post_consume_script(doc) with self.get_consumer(self.test_file) as consumer:
consumer.run_post_consume_script(doc)
m.assert_called_once() m.assert_called_once()
@ -1176,8 +1198,7 @@ class PostConsumeTestCase(TestCase):
doc.tags.add(tag1) doc.tags.add(tag1)
doc.tags.add(tag2) doc.tags.add(tag2)
consumer = Consumer() with self.get_consumer(self.test_file) as consumer:
consumer.task_id = str(uuid.uuid4())
consumer.run_post_consume_script(doc) consumer.run_post_consume_script(doc)
m.assert_called_once() m.assert_called_once()
@ -1225,8 +1246,8 @@ class PostConsumeTestCase(TestCase):
os.chmod(script.name, st.st_mode | stat.S_IEXEC) os.chmod(script.name, st.st_mode | stat.S_IEXEC)
with override_settings(POST_CONSUME_SCRIPT=script.name): with override_settings(POST_CONSUME_SCRIPT=script.name):
c = Consumer()
doc = Document.objects.create(title="Test", mime_type="application/pdf") doc = Document.objects.create(title="Test", mime_type="application/pdf")
c.path = "path-to-file" with self.get_consumer(self.test_file) as consumer:
with self.assertRaises(ConsumerError): with self.assertRaises(ConsumerError):
c.run_post_consume_script(doc) consumer.run_post_consume_script(doc)

View File

@ -46,7 +46,7 @@ class TestDoubleSided(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
with mock.patch( with mock.patch(
"documents.tasks.ProgressManager", "documents.tasks.ProgressManager",
DummyProgressManager, DummyProgressManager,
), mock.patch("documents.consumer.async_to_sync"): ):
msg = tasks.consume_file( msg = tasks.consume_file(
ConsumableDocument( ConsumableDocument(
source=DocumentSource.ConsumeFolder, source=DocumentSource.ConsumeFolder,

View File

@ -88,7 +88,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
return super().setUp() return super().setUp()
@mock.patch("documents.consumer.Consumer.try_consume_file") @mock.patch("documents.consumer.ConsumerPlugin.run")
def test_workflow_match(self, m): def test_workflow_match(self, m):
""" """
GIVEN: GIVEN:
@ -1467,7 +1467,7 @@ class TestWorkflows(DirectoriesMixin, FileSystemAssertsMixin, APITestCase):
expected_str = f"Document matched {trigger} from {w}" expected_str = f"Document matched {trigger} from {w}"
self.assertIn(expected_str, info) self.assertIn(expected_str, info)
@mock.patch("documents.consumer.Consumer.try_consume_file") @mock.patch("documents.consumer.ConsumerPlugin.run")
def test_removal_action_document_consumed_removeall(self, m): def test_removal_action_document_consumed_removeall(self, m):
""" """
GIVEN: GIVEN:

View File

@ -3,6 +3,7 @@ import tempfile
import time import time
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from collections.abc import Generator
from collections.abc import Iterator from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager
from os import PathLike from os import PathLike
@ -21,8 +22,10 @@ from django.db.migrations.executor import MigrationExecutor
from django.test import TransactionTestCase from django.test import TransactionTestCase
from django.test import override_settings from django.test import override_settings
from documents.consumer import ConsumerPlugin
from documents.data_models import ConsumableDocument from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentMetadataOverrides
from documents.data_models import DocumentSource
from documents.parsers import ParseError from documents.parsers import ParseError
from documents.plugins.helpers import ProgressStatusOptions from documents.plugins.helpers import ProgressStatusOptions
@ -326,6 +329,30 @@ class SampleDirMixin:
BARCODE_SAMPLE_DIR = SAMPLE_DIR / "barcodes" BARCODE_SAMPLE_DIR = SAMPLE_DIR / "barcodes"
class GetConsumerMixin:
@contextmanager
def get_consumer(
self,
filepath: Path,
overrides: DocumentMetadataOverrides | None = None,
source: DocumentSource = DocumentSource.ConsumeFolder,
) -> Generator[ConsumerPlugin, None, None]:
# Store this for verification
self.status = DummyProgressManager(filepath.name, None)
reader = ConsumerPlugin(
ConsumableDocument(source, original_file=filepath),
overrides or DocumentMetadataOverrides(),
self.status, # type: ignore
self.dirs.scratch_dir,
"task-id",
)
reader.setup()
try:
yield reader
finally:
reader.cleanup()
class DummyProgressManager: class DummyProgressManager:
""" """
A dummy handler for progress management that doesn't actually try to A dummy handler for progress management that doesn't actually try to