diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index 321a34ccb..47ee2ac85 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -960,3 +960,82 @@ class ShareLinkSerializer(OwnedObjectSerializer): def create(self, validated_data): validated_data["slug"] = get_random_string(50) return super().create(validated_data) + + +class BulkEditObjectPermissionsSerializer(serializers.Serializer, SetPermissionsMixin): + objects = serializers.ListField( + required=True, + allow_empty=False, + label="Objects", + write_only=True, + child=serializers.IntegerField(), + ) + + object_type = serializers.ChoiceField( + choices=[ + "tag", + "correspondent", + "document_type", + "storage_path", + ], + label="Object Type", + write_only=True, + ) + + owner = serializers.PrimaryKeyRelatedField( + queryset=User.objects.all(), + required=False, + allow_null=True, + ) + + permissions = serializers.DictField( + label="Set permissions", + allow_empty=False, + required=False, + write_only=True, + ) + + def get_object_class(self, object_type): + object_class = None + if object_type == "tag": + object_class = Tag + elif object_type == "correspondent": + object_class = Correspondent + elif object_type == "document_type": + object_class = DocumentType + elif object_type == "storage_path": + object_class = StoragePath + return object_class + + def _validate_objects(self, objects, object_type): + if not isinstance(objects, list): + raise serializers.ValidationError("objects must be a list") + if not all(isinstance(i, int) for i in objects): + raise serializers.ValidationError("objects must be a list of integers") + object_class = self.get_object_class(object_type) + if object_class is None: + raise serializers.ValidationError( + "Unknown object type.", + ) + count = object_class.objects.filter(id__in=objects).count() + if not count == len(objects): + raise serializers.ValidationError( + "Some ids in objects don't exist or were specified twice.", + ) + return objects + + def _validate_permissions(self, permissions): + self.validate_set_permissions( + permissions, + ) + + def validate(self, attrs): + object_type = attrs["object_type"] + objects = attrs["objects"] + permissions = attrs["permissions"] if "permissions" in attrs else None + + self._validate_objects(objects, object_type) + if permissions is not None: + self._validate_permissions(permissions) + + return attrs diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index f9c6da0a8..b390cf86b 100644 --- a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -25,6 +25,7 @@ from django.test import override_settings from django.utils import timezone from guardian.shortcuts import assign_perm from guardian.shortcuts import get_perms +from guardian.shortcuts import get_users_with_perms from rest_framework import status from rest_framework.test import APITestCase from whoosh.writing import AsyncWriter @@ -5088,3 +5089,224 @@ class TestApiGroup(DirectoriesMixin, APITestCase): returned_group1 = Group.objects.get(pk=group1.pk) self.assertEqual(returned_group1.name, "Updated Name 1") + + +class TestBulkEditObjectPermissions(APITestCase): + def setUp(self): + super().setUp() + + user = User.objects.create_superuser(username="temp_admin") + self.client.force_authenticate(user=user) + + self.t1 = Tag.objects.create(name="t1") + self.t2 = Tag.objects.create(name="t2") + self.user1 = User.objects.create(username="user1") + self.user2 = User.objects.create(username="user2") + self.user3 = User.objects.create(username="user3") + + def test_bulk_object_set_permissions(self): + """ + GIVEN: + - Existing objects + WHEN: + - bulk_edit_object_perms API endpoint is called + THEN: + - Permissions and / or owner are changed + """ + permissions = { + "view": { + "users": [self.user1.id, self.user2.id], + "groups": [], + }, + "change": { + "users": [self.user1.id], + "groups": [], + }, + } + + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [self.t1.id, self.t2.id], + "object_type": "tag", + "permissions": permissions, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn(self.user1, get_users_with_perms(self.t1)) + + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [self.c1.id], + "object_type": "correspondents", + "permissions": permissions, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn(self.user1, get_users_with_perms(self.c1)) + + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [self.dt1.id], + "object_type": "document_types", + "permissions": permissions, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn(self.user1, get_users_with_perms(self.dt1)) + + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [self.sp1.id], + "object_type": "storage_paths", + "permissions": permissions, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn(self.user1, get_users_with_perms(self.sp1)) + + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [self.t1.id, self.t2.id], + "object_type": "tag", + "owner": self.user3.id, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(Tag.objects.get(pk=self.t2.id).owner, self.user3) + + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [self.sp1.id], + "object_type": "storage_paths", + "owner": self.user3.id, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(StoragePath.objects.get(pk=self.sp1.id).owner, self.user3) + + def test_bulk_edit_object_permissions_insufficient_perms(self): + """ + GIVEN: + - Objects owned by user other than logged in user + WHEN: + - bulk_edit_object_perms API endpoint is called + THEN: + - User is not able to change permissions + """ + self.t1.owner = User.objects.get(username="temp_admin") + self.t1.save() + self.client.force_authenticate(user=self.user1) + + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [self.t1.id, self.t2.id], + "object_type": "tag", + "owner": self.user1.id, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.content, b"Insufficient permissions") + + def test_bulk_edit_object_permissions_validation(self): + """ + GIVEN: + - Existing objects + WHEN: + - bulk_edit_object_perms API endpoint is called with invalid params + THEN: + - Validation fails + """ + # not a list + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": self.t1.id, + "object_type": "tags", + "owner": self.user1.id, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # not a list of ints + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": ["one"], + "object_type": "tags", + "owner": self.user1.id, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # duplicates + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [self.t1.id, self.t2.id, self.t1.id], + "object_type": "tags", + "owner": self.user1.id, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + # not a valid object type + response = self.client.post( + "/api/bulk_edit_object_perms/", + json.dumps( + { + "objects": [1], + "object_type": "madeup", + "owner": self.user1.id, + }, + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/src/documents/views.py b/src/documents/views.py index 856b27e27..be6ce1ff7 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -63,6 +63,7 @@ from documents.permissions import PaperlessAdminPermissions from documents.permissions import PaperlessObjectPermissions from documents.permissions import get_objects_for_user_owner_aware from documents.permissions import has_perms_owner_aware +from documents.permissions import set_permissions_for_object from documents.tasks import consume_file from paperless import version from paperless.db import GnuPG @@ -98,6 +99,7 @@ from .parsers import get_parser_class_for_mime_type from .parsers import parse_date_generator from .serialisers import AcknowledgeTasksViewSerializer from .serialisers import BulkDownloadSerializer +from .serialisers import BulkEditObjectPermissionsSerializer from .serialisers import BulkEditSerializer from .serialisers import CorrespondentSerializer from .serialisers import DocumentListSerializer @@ -1205,3 +1207,44 @@ def serve_file(doc: Document, use_archive: bool, disposition: str): ) response["Content-Disposition"] = content_disposition return response + + +class BulkEditObjectPermissionsView(GenericAPIView, PassUserMixin): + permission_classes = (IsAuthenticated,) + serializer_class = BulkEditObjectPermissionsSerializer + parser_classes = (parsers.JSONParser,) + + def post(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + user = self.request.user + object_type = serializer.validated_data.get("object_type") + object_ids = serializer.validated_data.get("objects") + object_class = serializer.get_object_class(object_type) + permissions = serializer.validated_data.get("permissions") + owner = serializer.validated_data.get("owner") + + if not user.is_superuser: + objs = object_class.objects.filter(pk__in=object_ids) + has_perms = all((obj.owner == user or obj.owner is None) for obj in objs) + + if not has_perms: + return HttpResponseForbidden("Insufficient permissions") + + try: + qs = object_class.objects.filter(id__in=object_ids) + + if "owner" in serializer.validated_data: + qs.update(owner=owner) + + if "permissions" in serializer.validated_data: + for obj in qs: + set_permissions_for_object(permissions, obj) + + return Response({"result": "OK"}) + except Exception as e: + logger.warning(f"An error occurred performing bulk permissions edit: {e!s}") + return HttpResponseBadRequest( + "Error performing bulk permissions edit, check logs for more detail.", + ) diff --git a/src/paperless/urls.py b/src/paperless/urls.py index 5d24478aa..05e772ee0 100644 --- a/src/paperless/urls.py +++ b/src/paperless/urls.py @@ -12,6 +12,7 @@ from rest_framework.routers import DefaultRouter from documents.views import AcknowledgeTasksView from documents.views import BulkDownloadView +from documents.views import BulkEditObjectPermissionsView from documents.views import BulkEditView from documents.views import CorrespondentViewSet from documents.views import DocumentTypeViewSet @@ -109,6 +110,11 @@ urlpatterns = [ name="mail_accounts_test", ), path("token/", views.obtain_auth_token), + re_path( + "^bulk_edit_object_perms/", + BulkEditObjectPermissionsView.as_view(), + name="bulk_edit_object_permissions", + ), *api_router.urls, ], ),