diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index 83be5eea9..9698f65cf 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -17,6 +17,7 @@ from documents.data_models import ConsumableDocument from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentSource from documents.models import Correspondent +from documents.models import CustomField from documents.models import CustomFieldInstance from documents.models import Document from documents.models import DocumentType @@ -147,17 +148,34 @@ def modify_tags( def modify_custom_fields( doc_ids: list[int], - add_custom_fields, - remove_custom_fields, + add_custom_fields: list[int] | dict, + remove_custom_fields: list[int], ) -> Literal["OK"]: qs = Document.objects.filter(id__in=doc_ids).only("pk") affected_docs = list(qs.values_list("pk", flat=True)) + # Ensure add_custom_fields is a list of tuples, supports old API + add_custom_fields = ( + add_custom_fields.items() + if isinstance(add_custom_fields, dict) + else [(field, None) for field in add_custom_fields] + ) - for field in add_custom_fields: + custom_fields = CustomField.objects.filter( + id__in=[int(field) for field, _ in add_custom_fields], + ).distinct() + for field_id, value in add_custom_fields: for doc_id in affected_docs: + defaults = {} + custom_field = custom_fields.get(id=field_id) + if custom_field: + value_field = CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[ + custom_field.data_type + ] + defaults[value_field] = value CustomFieldInstance.objects.update_or_create( document_id=doc_id, - field_id=field, + field_id=field_id, + defaults=defaults, ) CustomFieldInstance.objects.filter( document_id__in=affected_docs, diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 91e291c21..09cfe4e01 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1140,13 +1140,27 @@ class BulkEditSerializer( f"Some tags in {name} don't exist or were specified twice.", ) - def _validate_custom_field_id_list(self, custom_fields, name="custom_fields"): - if not isinstance(custom_fields, list): - raise serializers.ValidationError(f"{name} must be a list") - if not all(isinstance(i, int) for i in custom_fields): - raise serializers.ValidationError(f"{name} must be a list of integers") - count = CustomField.objects.filter(id__in=custom_fields).count() - if not count == len(custom_fields): + def _validate_custom_field_id_list_or_dict( + self, + custom_fields, + name="custom_fields", + ): + ids = custom_fields + if isinstance(custom_fields, dict): + try: + ids = [int(i[0]) for i in custom_fields.items()] + except Exception as e: + raise serializers.ValidationError( + f"{name} must be a list of integers or a dict of key-value pairs: {e}", + ) + elif not isinstance(custom_fields, list) or not all( + isinstance(i, int) for i in ids + ): + raise serializers.ValidationError( + f"{name} must be a list of integers or a dict of key-value pairs", + ) + count = CustomField.objects.filter(id__in=ids).count() + if not count == len(ids): raise serializers.ValidationError( f"Some custom fields in {name} don't exist or were specified twice.", ) @@ -1245,7 +1259,7 @@ class BulkEditSerializer( def _validate_parameters_modify_custom_fields(self, parameters): if "add_custom_fields" in parameters: - self._validate_custom_field_id_list( + self._validate_custom_field_id_list_or_dict( parameters["add_custom_fields"], "add_custom_fields", ) @@ -1253,7 +1267,7 @@ class BulkEditSerializer( raise serializers.ValidationError("add_custom_fields not specified") if "remove_custom_fields" in parameters: - self._validate_custom_field_id_list( + self._validate_custom_field_id_list_or_dict( parameters["remove_custom_fields"], "remove_custom_fields", ) diff --git a/src/documents/tests/test_api_bulk_edit.py b/src/documents/tests/test_api_bulk_edit.py index 075bbfd6a..59002802c 100644 --- a/src/documents/tests/test_api_bulk_edit.py +++ b/src/documents/tests/test_api_bulk_edit.py @@ -244,7 +244,9 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): "documents": [self.doc1.id, self.doc3.id], "method": "modify_custom_fields", "parameters": { - "add_custom_fields": [self.cf1.id], + "add_custom_fields": [ + self.cf1.id, + ], # old format accepts list of IDs "remove_custom_fields": [self.cf2.id], }, }, @@ -258,6 +260,30 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase): self.assertEqual(kwargs["add_custom_fields"], [self.cf1.id]) self.assertEqual(kwargs["remove_custom_fields"], [self.cf2.id]) + @mock.patch("documents.serialisers.bulk_edit.modify_custom_fields") + def test_api_modify_custom_fields_with_values(self, m): + self.setup_mock(m, "modify_custom_fields") + response = self.client.post( + "/api/documents/bulk_edit/", + json.dumps( + { + "documents": [self.doc1.id, self.doc3.id], + "method": "modify_custom_fields", + "parameters": { + "add_custom_fields": {self.cf1.id: "foo"}, + "remove_custom_fields": [self.cf2.id], + }, + }, + ), + content_type="application/json", + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + m.assert_called_once() + args, kwargs = m.call_args + self.assertListEqual(args[0], [self.doc1.id, self.doc3.id]) + self.assertEqual(kwargs["add_custom_fields"], {str(self.cf1.id): "foo"}) + self.assertEqual(kwargs["remove_custom_fields"], [self.cf2.id]) + @mock.patch("documents.serialisers.bulk_edit.modify_custom_fields") def test_api_modify_custom_fields_invalid_params(self, m): """ diff --git a/src/documents/tests/test_bulk_edit.py b/src/documents/tests/test_bulk_edit.py index bb5ebf04d..00a72845f 100644 --- a/src/documents/tests/test_bulk_edit.py +++ b/src/documents/tests/test_bulk_edit.py @@ -189,6 +189,15 @@ class TestBulkEdit(DirectoriesMixin, TestCase): self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id]) def test_modify_custom_fields(self): + """ + GIVEN: + - 2 documents with custom fields + - 3 custom fields + WHEN: + - Custom fields are modified using old format (list of ids) + THEN: + - Custom fields are modified for the documents + """ cf = CustomField.objects.create( name="cf1", data_type=CustomField.FieldDataType.STRING, @@ -235,6 +244,78 @@ class TestBulkEdit(DirectoriesMixin, TestCase): args, kwargs = self.async_task.call_args self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id]) + def test_modify_custom_fields_with_values(self): + """ + GIVEN: + - 2 documents with custom fields + - 3 custom fields + WHEN: + - Custom fields are modified using new format (dict) + THEN: + - Custom fields are modified for the documents + """ + cf = CustomField.objects.create( + name="cf", + data_type=CustomField.FieldDataType.STRING, + ) + cf1 = CustomField.objects.create( + name="cf1", + data_type=CustomField.FieldDataType.STRING, + ) + cf2 = CustomField.objects.create( + name="cf2", + data_type=CustomField.FieldDataType.INT, + ) + cf3 = CustomField.objects.create( + name="cf3", + data_type=CustomField.FieldDataType.STRING, + ) + CustomFieldInstance.objects.create( + document=self.doc2, + field=cf, + ) + CustomFieldInstance.objects.create( + document=self.doc2, + field=cf1, + ) + CustomFieldInstance.objects.create( + document=self.doc2, + field=cf3, + ) + bulk_edit.modify_custom_fields( + [self.doc1.id, self.doc2.id], + add_custom_fields={cf2.id: None, cf3.id: "value"}, + remove_custom_fields=[cf.id], + ) + + self.doc1.refresh_from_db() + self.doc2.refresh_from_db() + + self.assertEqual( + self.doc1.custom_fields.count(), + 2, + ) + self.assertEqual( + self.doc1.custom_fields.get(field=cf2).value, + None, + ) + self.assertEqual( + self.doc1.custom_fields.get(field=cf3).value, + "value", + ) + self.assertEqual( + self.doc2.custom_fields.count(), + 3, + ) + self.assertEqual( + self.doc2.custom_fields.get(field=cf3).value, + "value", + ) + + self.async_task.assert_called_once() + args, kwargs = self.async_task.call_args + self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id]) + def test_delete(self): self.assertEqual(Document.objects.count(), 5) bulk_edit.delete([self.doc1.id, self.doc2.id])