Removes some extra duplication

Trenton H 2024-04-10 08:43:06 -07:00
parent e17c78add6
commit 7dad89332d
3 changed files with 65 additions and 85 deletions
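The duplication removed here is twofold: the consumer's private ConsumerFilePhase enum, which mirrored the progress phases already provided by the plugin helpers, and the self.override_* attributes, which were plain copies of fields on self.metadata. After this change the consumer imports ProgressStatusOptions and reads overrides straight from the metadata object. For orientation, a rough sketch of the shared enum as it is assumed to look in src/documents/plugins/helpers.py; the member names are the ones this commit uses, while the base classes and string values are assumptions, not a copy of that file:

from enum import Enum


# Sketch only: assumed shape of ProgressStatusOptions in
# src/documents/plugins/helpers.py. The members STARTED, WORKING, SUCCESS and
# FAILED are the ones used in this commit; the string values mirror the
# deleted ConsumerFilePhase and are an assumption here.
class ProgressStatusOptions(str, Enum):
    STARTED = "STARTED"
    WORKING = "WORKING"
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"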

View File

@@ -4,7 +4,7 @@ import { environment } from 'src/environments/environment'
 import { WebsocketConsumerStatusMessage } from '../data/websocket-consumer-status-message'
 import { SettingsService } from './settings.service'
 
-// see ConsumerFilePhase in src/documents/consumer.py
+// see ProgressStatusOptions in src/documents/plugins/helpers.py
 export enum FileStatusPhase {
   STARTED = 0,
   UPLOADING = 1,

View File

@@ -45,6 +45,7 @@ from documents.plugins.base import ConsumeTaskPlugin
 from documents.plugins.base import NoCleanupPluginMixin
 from documents.plugins.base import NoSetupPluginMixin
 from documents.plugins.helpers import ProgressManager
+from documents.plugins.helpers import ProgressStatusOptions
 from documents.signals import document_consumption_finished
 from documents.signals import document_consumption_started
 from documents.utils import copy_basic_file_stats
@@ -247,13 +248,6 @@ class ConsumerStatusShortMessage(str, Enum):
     FAILED = "failed"
 
 
-class ConsumerFilePhase(str, Enum):
-    STARTED = "STARTED"
-    WORKING = "WORKING"
-    SUCCESS = "SUCCESS"
-    FAILED = "FAILED"
-
-
 class ConsumerPlugin(
     AlwaysRunPluginMixin,
     NoSetupPluginMixin,
@@ -275,27 +269,13 @@ class ConsumerPlugin(
         self.renew_logging_group()
 
-        self.original_path = self.input_doc.original_file
         self.filename = self.metadata.filename or self.input_doc.original_file.name
-        self.override_title = self.metadata.title
-        self.override_correspondent_id = self.metadata.correspondent_id
-        self.override_document_type_id = self.metadata.document_type_id
-        self.override_tag_ids = self.metadata.tag_ids
-        self.override_storage_path_id = self.metadata.storage_path_id
-        self.override_created = self.metadata.created
-        self.override_asn = self.metadata.asn
-        self.override_owner_id = self.metadata.owner_id
-        self.override_view_users = self.metadata.view_users
-        self.override_view_groups = self.metadata.view_groups
-        self.override_change_users = self.metadata.change_users
-        self.override_change_groups = self.metadata.change_groups
-        self.override_custom_field_ids = self.metadata.custom_field_ids
 
     def _send_progress(
         self,
         current_progress: int,
         max_progress: int,
-        status: ConsumerFilePhase,
+        status: ProgressStatusOptions,
         message: Optional[Union[ConsumerStatusShortMessage, str]] = None,
         document_id=None,
     ):  # pragma: no cover
@@ -306,7 +286,7 @@ class ConsumerPlugin(
             max_progress,
             extra_args={
                 "document_id": document_id,
-                "owner_id": self.override_owner_id if self.override_owner_id else None,
+                "owner_id": self.metadata.owner_id if self.metadata.owner_id else None,
             },
         )
@@ -317,7 +297,7 @@ class ConsumerPlugin(
         exc_info=None,
         exception: Optional[Exception] = None,
     ):
-        self._send_progress(100, 100, ConsumerFilePhase.FAILED, message)
+        self._send_progress(100, 100, ProgressStatusOptions.FAILED, message)
         self.log.error(log_message or message, exc_info=exc_info)
         raise ConsumerError(f"{self.filename}: {log_message or message}") from exception
@@ -325,24 +305,24 @@ class ConsumerPlugin(
         """
         Confirm the input file still exists where it should
         """
-        if not os.path.isfile(self.original_path):
+        if not os.path.isfile(self.input_doc.original_file):
             self._fail(
                 ConsumerStatusShortMessage.FILE_NOT_FOUND,
-                f"Cannot consume {self.original_path}: File not found.",
+                f"Cannot consume {self.input_doc.original_file}: File not found.",
             )
 
     def pre_check_duplicate(self):
         """
         Using the MD5 of the file, check this exact file doesn't already exist
         """
-        with open(self.original_path, "rb") as f:
+        with open(self.input_doc.original_file, "rb") as f:
             checksum = hashlib.md5(f.read()).hexdigest()
         existing_doc = Document.objects.filter(
             Q(checksum=checksum) | Q(archive_checksum=checksum),
         )
         if existing_doc.exists():
             if settings.CONSUMER_DELETE_DUPLICATES:
-                os.unlink(self.original_path)
+                os.unlink(self.input_doc.original_file)
             self._fail(
                 ConsumerStatusShortMessage.DOCUMENT_ALREADY_EXISTS,
                 f"Not consuming {self.filename}: It is a duplicate of"
@@ -362,26 +342,26 @@ class ConsumerPlugin(
         """
         Check that if override_asn is given, it is unique and within a valid range
         """
-        if not self.override_asn:
+        if not self.metadata.asn:
             # check not necessary in case no ASN gets set
             return
         # Validate the range is above zero and less than uint32_t max
         # otherwise, Whoosh can't handle it in the index
         if (
-            self.override_asn < Document.ARCHIVE_SERIAL_NUMBER_MIN
-            or self.override_asn > Document.ARCHIVE_SERIAL_NUMBER_MAX
+            self.metadata.asn < Document.ARCHIVE_SERIAL_NUMBER_MIN
+            or self.metadata.asn > Document.ARCHIVE_SERIAL_NUMBER_MAX
         ):
             self._fail(
                 ConsumerStatusShortMessage.ASN_RANGE,
                 f"Not consuming {self.filename}: "
-                f"Given ASN {self.override_asn} is out of range "
+                f"Given ASN {self.metadata.asn} is out of range "
                 f"[{Document.ARCHIVE_SERIAL_NUMBER_MIN:,}, "
                 f"{Document.ARCHIVE_SERIAL_NUMBER_MAX:,}]",
             )
-        if Document.objects.filter(archive_serial_number=self.override_asn).exists():
+        if Document.objects.filter(archive_serial_number=self.metadata.asn).exists():
             self._fail(
                 ConsumerStatusShortMessage.ASN_ALREADY_EXISTS,
-                f"Not consuming {self.filename}: Given ASN {self.override_asn} already exists!",
+                f"Not consuming {self.filename}: Given ASN {self.metadata.asn} already exists!",
             )
 
     def run_pre_consume_script(self):
@@ -402,7 +382,7 @@ class ConsumerPlugin(
         self.log.info(f"Executing pre-consume script {settings.PRE_CONSUME_SCRIPT}")
 
         working_file_path = str(self.working_copy)
-        original_file_path = str(self.original_path)
+        original_file_path = str(self.input_doc.original_file)
 
         script_env = os.environ.copy()
         script_env["DOCUMENT_SOURCE_PATH"] = original_file_path
@@ -508,7 +488,7 @@ class ConsumerPlugin(
         self._send_progress(
             0,
             100,
-            ConsumerFilePhase.STARTED,
+            ProgressStatusOptions.STARTED,
             ConsumerStatusShortMessage.NEW_FILE,
         )
@@ -527,7 +507,7 @@ class ConsumerPlugin(
             dir=settings.SCRATCH_DIR,
         )
         self.working_copy = Path(tempdir.name) / Path(self.filename)
-        copy_file_with_basic_stats(self.original_path, self.working_copy)
+        copy_file_with_basic_stats(self.input_doc.original_file, self.working_copy)
 
         # Determine the parser class.
@@ -559,7 +539,7 @@ class ConsumerPlugin(
         def progress_callback(current_progress, max_progress):  # pragma: no cover
             # recalculate progress to be within 20 and 80
             p = int((current_progress / max_progress) * 50 + 20)
-            self._send_progress(p, 100, ConsumerFilePhase.WORKING)
+            self._send_progress(p, 100, ProgressStatusOptions.WORKING)
 
         # This doesn't parse the document yet, but gives us a parser.
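An aside on the hunk above: the parser's progress callback is rescaled into a fixed band of the overall consume progress. The code comment says "within 20 and 80", but the formula itself maps 0..max_progress onto 20..70. A small self-contained illustration (the name rescale is invented for this sketch):

def rescale(current_progress: int, max_progress: int) -> int:
    # Same arithmetic as progress_callback above: parser progress is mapped
    # into the 20-70 portion of the overall progress bar.
    return int((current_progress / max_progress) * 50 + 20)


assert rescale(0, 100) == 20
assert rescale(50, 100) == 45
assert rescale(100, 100) == 70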
@@ -581,7 +561,7 @@ class ConsumerPlugin(
             self._send_progress(
                 20,
                 100,
-                ConsumerFilePhase.WORKING,
+                ProgressStatusOptions.WORKING,
                 ConsumerStatusShortMessage.PARSING_DOCUMENT,
             )
             self.log.debug(f"Parsing {self.filename}...")
@@ -591,7 +571,7 @@ class ConsumerPlugin(
             self._send_progress(
                 70,
                 100,
-                ConsumerFilePhase.WORKING,
+                ProgressStatusOptions.WORKING,
                 ConsumerStatusShortMessage.GENERATING_THUMBNAIL,
             )
             thumbnail = document_parser.get_thumbnail(
@@ -606,7 +586,7 @@ class ConsumerPlugin(
             self._send_progress(
                 90,
                 100,
-                ConsumerFilePhase.WORKING,
+                ProgressStatusOptions.WORKING,
                 ConsumerStatusShortMessage.PARSE_DATE,
             )
             date = parse_date(self.filename, text)
@@ -640,7 +620,7 @@ class ConsumerPlugin(
         self._send_progress(
             95,
             100,
-            ConsumerFilePhase.WORKING,
+            ProgressStatusOptions.WORKING,
             ConsumerStatusShortMessage.SAVE_DOCUMENT,
         )
         # now that everything is done, we can start to store the document
@@ -702,13 +682,13 @@ class ConsumerPlugin(
                 # Delete the file only if it was successfully consumed
                 self.log.debug(f"Deleting file {self.working_copy}")
-                self.original_path.unlink()
+                self.input_doc.original_file.unlink()
                 self.working_copy.unlink()
 
                 # https://github.com/jonaswinkler/paperless-ng/discussions/1037
                 shadow_file = os.path.join(
-                    os.path.dirname(self.original_path),
-                    "._" + os.path.basename(self.original_path),
+                    os.path.dirname(self.input_doc.original_file),
+                    "._" + os.path.basename(self.input_doc.original_file),
                 )
 
                 if os.path.isfile(shadow_file):
@@ -734,7 +714,7 @@ class ConsumerPlugin(
         self._send_progress(
             100,
             100,
-            ConsumerFilePhase.SUCCESS,
+            ProgressStatusOptions.SUCCESS,
             ConsumerStatusShortMessage.FINISHED,
             document.id,
         )
@@ -748,18 +728,18 @@ class ConsumerPlugin(
         local_added = timezone.localtime(timezone.now())
 
         correspondent_name = (
-            Correspondent.objects.get(pk=self.override_correspondent_id).name
-            if self.override_correspondent_id is not None
+            Correspondent.objects.get(pk=self.metadata.correspondent_id).name
+            if self.metadata.correspondent_id is not None
             else None
         )
         doc_type_name = (
-            DocumentType.objects.get(pk=self.override_document_type_id).name
-            if self.override_document_type_id is not None
+            DocumentType.objects.get(pk=self.metadata.document_type_id).name
+            if self.metadata.document_type_id is not None
             else None
         )
         owner_username = (
-            User.objects.get(pk=self.override_owner_id).username
-            if self.override_owner_id is not None
+            User.objects.get(pk=self.metadata.owner_id).username
+            if self.metadata.owner_id is not None
             else None
         )
@@ -784,8 +764,8 @@ class ConsumerPlugin(
         self.log.debug("Saving record to database")
 
-        if self.override_created is not None:
-            create_date = self.override_created
+        if self.metadata.created is not None:
+            create_date = self.metadata.created
             self.log.debug(
                 f"Creation date from post_documents parameter: {create_date}",
             )
@@ -796,7 +776,7 @@ class ConsumerPlugin(
             create_date = date
             self.log.debug(f"Creation date from parse_date: {create_date}")
         else:
-            stats = os.stat(self.original_path)
+            stats = os.stat(self.input_doc.original_file)
             create_date = timezone.make_aware(
                 datetime.datetime.fromtimestamp(stats.st_mtime),
             )
@@ -805,12 +785,12 @@ class ConsumerPlugin(
         storage_type = Document.STORAGE_TYPE_UNENCRYPTED
 
         title = file_info.title
-        if self.override_title is not None:
+        if self.metadata.title is not None:
             try:
-                title = self._parse_title_placeholders(self.override_title)
+                title = self._parse_title_placeholders(self.metadata.title)
             except Exception as e:
                 self.log.error(
-                    f"Error occurred parsing title override '{self.override_title}', falling back to original. Exception: {e}",
+                    f"Error occurred parsing title override '{self.metadata.title}', falling back to original. Exception: {e}",
                 )
 
         document = Document.objects.create(
@@ -831,53 +811,53 @@ class ConsumerPlugin(
         return document
 
     def apply_overrides(self, document):
-        if self.override_correspondent_id:
+        if self.metadata.correspondent_id:
             document.correspondent = Correspondent.objects.get(
-                pk=self.override_correspondent_id,
+                pk=self.metadata.correspondent_id,
             )
 
-        if self.override_document_type_id:
+        if self.metadata.document_type_id:
             document.document_type = DocumentType.objects.get(
-                pk=self.override_document_type_id,
+                pk=self.metadata.document_type_id,
             )
 
-        if self.override_tag_ids:
-            for tag_id in self.override_tag_ids:
+        if self.metadata.tag_ids:
+            for tag_id in self.metadata.tag_ids:
                 document.tags.add(Tag.objects.get(pk=tag_id))
 
-        if self.override_storage_path_id:
+        if self.metadata.storage_path_id:
             document.storage_path = StoragePath.objects.get(
-                pk=self.override_storage_path_id,
+                pk=self.metadata.storage_path_id,
            )
 
-        if self.override_asn:
-            document.archive_serial_number = self.override_asn
+        if self.metadata.asn:
+            document.archive_serial_number = self.metadata.asn
 
-        if self.override_owner_id:
+        if self.metadata.owner_id:
             document.owner = User.objects.get(
-                pk=self.override_owner_id,
+                pk=self.metadata.owner_id,
             )
 
         if (
-            self.override_view_users is not None
-            or self.override_view_groups is not None
-            or self.override_change_users is not None
-            or self.override_change_users is not None
+            self.metadata.view_users is not None
+            or self.metadata.view_groups is not None
+            or self.metadata.change_users is not None
+            or self.metadata.change_users is not None
         ):
             permissions = {
                 "view": {
-                    "users": self.override_view_users or [],
-                    "groups": self.override_view_groups or [],
+                    "users": self.metadata.view_users or [],
+                    "groups": self.metadata.view_groups or [],
                 },
                 "change": {
-                    "users": self.override_change_users or [],
-                    "groups": self.override_change_groups or [],
+                    "users": self.metadata.change_users or [],
+                    "groups": self.metadata.change_groups or [],
                 },
             }
             set_permissions_for_object(permissions=permissions, object=document)
 
-        if self.override_custom_field_ids:
-            for field_id in self.override_custom_field_ids:
+        if self.metadata.custom_field_ids:
+            for field_id in self.metadata.custom_field_ids:
                 field = CustomField.objects.get(pk=field_id)
                 CustomFieldInstance.objects.create(
                     field=field,
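With the override_* copies gone, _store and apply_overrides read everything from self.metadata, the DocumentMetadataOverrides instance (imported in the test file below). A rough sketch of the fields the consumer relies on; the names are confirmed by the assignments deleted above, while the dataclass form, types and defaults are assumptions, not a copy of src/documents/data_models.py:

import dataclasses
import datetime
from typing import Optional


# Sketch only: the DocumentMetadataOverrides fields the consumer reads via
# self.metadata. Field names come from the removed override_* assignments;
# types and defaults are assumed for illustration.
@dataclasses.dataclass
class DocumentMetadataOverrides:
    filename: Optional[str] = None
    title: Optional[str] = None
    correspondent_id: Optional[int] = None
    document_type_id: Optional[int] = None
    tag_ids: Optional[list[int]] = None
    storage_path_id: Optional[int] = None
    created: Optional[datetime.datetime] = None
    asn: Optional[int] = None
    owner_id: Optional[int] = None
    view_users: Optional[list[int]] = None
    view_groups: Optional[list[int]] = None
    change_users: Optional[list[int]] = None
    change_groups: Optional[list[int]] = None
    custom_field_ids: Optional[list[int]] = None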

View File

@@ -20,7 +20,6 @@ from django.utils import timezone
 from guardian.core import ObjectPermissionChecker
 
 from documents.consumer import ConsumerError
-from documents.consumer import ConsumerFilePhase
 from documents.data_models import DocumentMetadataOverrides
 from documents.models import Correspondent
 from documents.models import CustomField
@@ -31,6 +30,7 @@ from documents.models import StoragePath
 from documents.models import Tag
 from documents.parsers import DocumentParser
 from documents.parsers import ParseError
+from documents.plugins.helpers import ProgressStatusOptions
 from documents.tasks import sanity_check
 from documents.tests.utils import DirectoriesMixin
 from documents.tests.utils import FileSystemAssertsMixin
@@ -256,8 +256,8 @@ class TestConsumer(
 ):
     def _assert_first_last_send_progress(
         self,
-        first_status=ConsumerFilePhase.STARTED,
-        last_status=ConsumerFilePhase.SUCCESS,
+        first_status=ProgressStatusOptions.STARTED,
+        last_status=ProgressStatusOptions.SUCCESS,
         first_progress=0,
         first_progress_max=100,
         last_progress=100,
@@ -1077,7 +1077,7 @@ class PreConsumeTestCase(DirectoriesMixin, GetConsumerMixin, TestCase):
             self.assertEqual(command[1], str(self.test_file))
 
             subset = {
-                "DOCUMENT_SOURCE_PATH": str(c.original_path),
+                "DOCUMENT_SOURCE_PATH": str(c.input_doc.original_file),
                 "DOCUMENT_WORKING_PATH": str(c.working_copy),
                 "TASK_ID": c.task_id,
             }
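The environment subset asserted here lines up with what run_pre_consume_script builds in the consumer hunk above. A minimal sketch of that wiring, under the assumption that the working-path variable is set the same way as the source path (only the DOCUMENT_SOURCE_PATH assignment is visible in this diff); the paths are illustrative:

import os
from pathlib import Path

# Illustrative stand-ins for self.input_doc.original_file and
# self.working_copy; the real values come from the consumer.
original_file = Path("/data/consume/scan.pdf")
working_copy = Path("/tmp/paperless/scan.pdf")

script_env = os.environ.copy()
script_env["DOCUMENT_SOURCE_PATH"] = str(original_file)
script_env["DOCUMENT_WORKING_PATH"] = str(working_copy)  # assumed, see note above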