This commit is contained in:
shamoon 2024-12-03 16:25:26 -08:00
parent 9ed88d4acb
commit 9a452cc396
4 changed files with 153 additions and 14 deletions

View File

@ -17,6 +17,7 @@ from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides from documents.data_models import DocumentMetadataOverrides
from documents.data_models import DocumentSource from documents.data_models import DocumentSource
from documents.models import Correspondent from documents.models import Correspondent
from documents.models import CustomField
from documents.models import CustomFieldInstance from documents.models import CustomFieldInstance
from documents.models import Document from documents.models import Document
from documents.models import DocumentType from documents.models import DocumentType
@ -147,17 +148,34 @@ def modify_tags(
def modify_custom_fields( def modify_custom_fields(
doc_ids: list[int], doc_ids: list[int],
add_custom_fields, add_custom_fields: list[int] | dict,
remove_custom_fields, remove_custom_fields: list[int],
) -> Literal["OK"]: ) -> Literal["OK"]:
qs = Document.objects.filter(id__in=doc_ids).only("pk") qs = Document.objects.filter(id__in=doc_ids).only("pk")
affected_docs = list(qs.values_list("pk", flat=True)) affected_docs = list(qs.values_list("pk", flat=True))
# Ensure add_custom_fields is a list of tuples, supports old API
add_custom_fields = (
add_custom_fields.items()
if isinstance(add_custom_fields, dict)
else [(field, None) for field in add_custom_fields]
)
for field in add_custom_fields: custom_fields = CustomField.objects.filter(
id__in=[int(field) for field, _ in add_custom_fields],
).distinct()
for field_id, value in add_custom_fields:
for doc_id in affected_docs: for doc_id in affected_docs:
defaults = {}
custom_field = custom_fields.get(id=field_id)
if custom_field:
value_field = CustomFieldInstance.TYPE_TO_DATA_STORE_NAME_MAP[
custom_field.data_type
]
defaults[value_field] = value
CustomFieldInstance.objects.update_or_create( CustomFieldInstance.objects.update_or_create(
document_id=doc_id, document_id=doc_id,
field_id=field, field_id=field_id,
defaults=defaults,
) )
CustomFieldInstance.objects.filter( CustomFieldInstance.objects.filter(
document_id__in=affected_docs, document_id__in=affected_docs,

View File

@ -1140,13 +1140,27 @@ class BulkEditSerializer(
f"Some tags in {name} don't exist or were specified twice.", f"Some tags in {name} don't exist or were specified twice.",
) )
def _validate_custom_field_id_list(self, custom_fields, name="custom_fields"): def _validate_custom_field_id_list_or_dict(
if not isinstance(custom_fields, list): self,
raise serializers.ValidationError(f"{name} must be a list") custom_fields,
if not all(isinstance(i, int) for i in custom_fields): name="custom_fields",
raise serializers.ValidationError(f"{name} must be a list of integers") ):
count = CustomField.objects.filter(id__in=custom_fields).count() ids = custom_fields
if not count == len(custom_fields): if isinstance(custom_fields, dict):
try:
ids = [int(i[0]) for i in custom_fields.items()]
except Exception as e:
raise serializers.ValidationError(
f"{name} must be a list of integers or a dict of key-value pairs: {e}",
)
elif not isinstance(custom_fields, list) or not all(
isinstance(i, int) for i in ids
):
raise serializers.ValidationError(
f"{name} must be a list of integers or a dict of key-value pairs",
)
count = CustomField.objects.filter(id__in=ids).count()
if not count == len(ids):
raise serializers.ValidationError( raise serializers.ValidationError(
f"Some custom fields in {name} don't exist or were specified twice.", f"Some custom fields in {name} don't exist or were specified twice.",
) )
@ -1245,7 +1259,7 @@ class BulkEditSerializer(
def _validate_parameters_modify_custom_fields(self, parameters): def _validate_parameters_modify_custom_fields(self, parameters):
if "add_custom_fields" in parameters: if "add_custom_fields" in parameters:
self._validate_custom_field_id_list( self._validate_custom_field_id_list_or_dict(
parameters["add_custom_fields"], parameters["add_custom_fields"],
"add_custom_fields", "add_custom_fields",
) )
@ -1253,7 +1267,7 @@ class BulkEditSerializer(
raise serializers.ValidationError("add_custom_fields not specified") raise serializers.ValidationError("add_custom_fields not specified")
if "remove_custom_fields" in parameters: if "remove_custom_fields" in parameters:
self._validate_custom_field_id_list( self._validate_custom_field_id_list_or_dict(
parameters["remove_custom_fields"], parameters["remove_custom_fields"],
"remove_custom_fields", "remove_custom_fields",
) )

View File

@ -244,7 +244,9 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
"documents": [self.doc1.id, self.doc3.id], "documents": [self.doc1.id, self.doc3.id],
"method": "modify_custom_fields", "method": "modify_custom_fields",
"parameters": { "parameters": {
"add_custom_fields": [self.cf1.id], "add_custom_fields": [
self.cf1.id,
], # old format accepts list of IDs
"remove_custom_fields": [self.cf2.id], "remove_custom_fields": [self.cf2.id],
}, },
}, },
@ -258,6 +260,30 @@ class TestBulkEditAPI(DirectoriesMixin, APITestCase):
self.assertEqual(kwargs["add_custom_fields"], [self.cf1.id]) self.assertEqual(kwargs["add_custom_fields"], [self.cf1.id])
self.assertEqual(kwargs["remove_custom_fields"], [self.cf2.id]) self.assertEqual(kwargs["remove_custom_fields"], [self.cf2.id])
@mock.patch("documents.serialisers.bulk_edit.modify_custom_fields")
def test_api_modify_custom_fields_with_values(self, m):
self.setup_mock(m, "modify_custom_fields")
response = self.client.post(
"/api/documents/bulk_edit/",
json.dumps(
{
"documents": [self.doc1.id, self.doc3.id],
"method": "modify_custom_fields",
"parameters": {
"add_custom_fields": {self.cf1.id: "foo"},
"remove_custom_fields": [self.cf2.id],
},
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
m.assert_called_once()
args, kwargs = m.call_args
self.assertListEqual(args[0], [self.doc1.id, self.doc3.id])
self.assertEqual(kwargs["add_custom_fields"], {str(self.cf1.id): "foo"})
self.assertEqual(kwargs["remove_custom_fields"], [self.cf2.id])
@mock.patch("documents.serialisers.bulk_edit.modify_custom_fields") @mock.patch("documents.serialisers.bulk_edit.modify_custom_fields")
def test_api_modify_custom_fields_invalid_params(self, m): def test_api_modify_custom_fields_invalid_params(self, m):
""" """

View File

@ -189,6 +189,15 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id]) self.assertCountEqual(kwargs["document_ids"], [self.doc2.id, self.doc3.id])
def test_modify_custom_fields(self): def test_modify_custom_fields(self):
"""
GIVEN:
- 2 documents with custom fields
- 3 custom fields
WHEN:
- Custom fields are modified using old format (list of ids)
THEN:
- Custom fields are modified for the documents
"""
cf = CustomField.objects.create( cf = CustomField.objects.create(
name="cf1", name="cf1",
data_type=CustomField.FieldDataType.STRING, data_type=CustomField.FieldDataType.STRING,
@ -235,6 +244,78 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
args, kwargs = self.async_task.call_args args, kwargs = self.async_task.call_args
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id]) self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
def test_modify_custom_fields_with_values(self):
"""
GIVEN:
- 2 documents with custom fields
- 3 custom fields
WHEN:
- Custom fields are modified using new format (dict)
THEN:
- Custom fields are modified for the documents
"""
cf = CustomField.objects.create(
name="cf",
data_type=CustomField.FieldDataType.STRING,
)
cf1 = CustomField.objects.create(
name="cf1",
data_type=CustomField.FieldDataType.STRING,
)
cf2 = CustomField.objects.create(
name="cf2",
data_type=CustomField.FieldDataType.INT,
)
cf3 = CustomField.objects.create(
name="cf3",
data_type=CustomField.FieldDataType.STRING,
)
CustomFieldInstance.objects.create(
document=self.doc2,
field=cf,
)
CustomFieldInstance.objects.create(
document=self.doc2,
field=cf1,
)
CustomFieldInstance.objects.create(
document=self.doc2,
field=cf3,
)
bulk_edit.modify_custom_fields(
[self.doc1.id, self.doc2.id],
add_custom_fields={cf2.id: None, cf3.id: "value"},
remove_custom_fields=[cf.id],
)
self.doc1.refresh_from_db()
self.doc2.refresh_from_db()
self.assertEqual(
self.doc1.custom_fields.count(),
2,
)
self.assertEqual(
self.doc1.custom_fields.get(field=cf2).value,
None,
)
self.assertEqual(
self.doc1.custom_fields.get(field=cf3).value,
"value",
)
self.assertEqual(
self.doc2.custom_fields.count(),
3,
)
self.assertEqual(
self.doc2.custom_fields.get(field=cf3).value,
"value",
)
self.async_task.assert_called_once()
args, kwargs = self.async_task.call_args
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
def test_delete(self): def test_delete(self):
self.assertEqual(Document.objects.count(), 5) self.assertEqual(Document.objects.count(), 5)
bulk_edit.delete([self.doc1.id, self.doc2.id]) bulk_edit.delete([self.doc1.id, self.doc2.id])