diff --git a/docs/api.md b/docs/api.md index 0eacd7913..cbf94cf6a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -21,6 +21,7 @@ The API provides the following main endpoints: - `/api/groups/`: Full CRUD support. - `/api/share_links/`: Full CRUD support. - `/api/custom_fields/`: Full CRUD support. +- `/api/profile/`: GET, PATCH All of these endpoints except for the logging endpoint allow you to fetch (and edit and delete where appropriate) individual objects by diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index 9000c3c21..71bfdc5ff 100644 --- a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -5793,3 +5793,62 @@ class TestApiConsumptionTemplates(DirectoriesMixin, APITestCase): self.assertEqual(ConsumptionTemplate.objects.count(), 2) ct = ConsumptionTemplate.objects.get(name="Template 2") self.assertEqual(ct.sources, [int(DocumentSource.MailFetch).__str__()]) + + +class TestApiProfile(DirectoriesMixin, APITestCase): + ENDPOINT = "/api/profile/" + + def setUp(self): + super().setUp() + + self.user = User.objects.create_superuser( + username="temp_admin", + first_name="firstname", + last_name="surname", + ) + self.client.force_authenticate(user=self.user) + + def test_get_profile(self): + """ + GIVEN: + - Configured user + WHEN: + - API call is made to get profile + THEN: + - Profile is returned + """ + + response = self.client.get(self.ENDPOINT) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.assertEqual(response.data["email"], self.user.email) + self.assertEqual(response.data["password"], "**********") + self.assertEqual(response.data["first_name"], self.user.first_name) + self.assertEqual(response.data["last_name"], self.user.last_name) + + def test_update_profile(self): + """ + GIVEN: + - Configured user + WHEN: + - API call is made to update profile + THEN: + - Profile is updated + """ + + user_data = { + "email": "new@email.com", + "password": "superpassword1234", + "first_name": "new first name", + "last_name": "new last name", + } + response = self.client.patch(self.ENDPOINT, user_data) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + user = User.objects.get(username=self.user.username) + self.assertTrue(user.check_password(user_data["password"])) + self.assertEqual(user.email, user_data["email"]) + self.assertEqual(user.first_name, user_data["first_name"]) + self.assertEqual(user.last_name, user_data["last_name"]) diff --git a/src/paperless/serialisers.py b/src/paperless/serialisers.py index 4094a6538..572722e13 100644 --- a/src/paperless/serialisers.py +++ b/src/paperless/serialisers.py @@ -97,3 +97,17 @@ class GroupSerializer(serializers.ModelSerializer): "name", "permissions", ) + + +class ProfileSerializer(serializers.ModelSerializer): + email = serializers.EmailField(allow_null=False) + password = ObfuscatedUserPasswordField(required=False) + + class Meta: + model = User + fields = ( + "email", + "password", + "first_name", + "last_name", + ) diff --git a/src/paperless/urls.py b/src/paperless/urls.py index 2f0c56267..4371bc983 100644 --- a/src/paperless/urls.py +++ b/src/paperless/urls.py @@ -36,6 +36,7 @@ from documents.views import UnifiedSearchViewSet from paperless.consumers import StatusConsumer from paperless.views import FaviconView from paperless.views import GroupViewSet +from paperless.views import ProfileView from paperless.views import UserViewSet from paperless_mail.views import MailAccountTestView from paperless_mail.views import MailAccountViewSet @@ -119,6 +120,11 @@ urlpatterns = [ BulkEditObjectPermissionsView.as_view(), name="bulk_edit_object_permissions", ), + re_path( + "^profile/", + ProfileView.as_view(), + name="profile_view", + ), *api_router.urls, ], ), diff --git a/src/paperless/views.py b/src/paperless/views.py index e872cc19c..aac8dfac2 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -8,6 +8,7 @@ from django.http import HttpResponse from django.views.generic import View from django_filters.rest_framework import DjangoFilterBackend from rest_framework.filters import OrderingFilter +from rest_framework.generics import GenericAPIView from rest_framework.pagination import PageNumberPagination from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response @@ -17,6 +18,7 @@ from documents.permissions import PaperlessObjectPermissions from paperless.filters import GroupFilterSet from paperless.filters import UserFilterSet from paperless.serialisers import GroupSerializer +from paperless.serialisers import ProfileSerializer from paperless.serialisers import UserSerializer @@ -106,3 +108,42 @@ class GroupViewSet(ModelViewSet): filter_backends = (DjangoFilterBackend, OrderingFilter) filterset_class = GroupFilterSet ordering_fields = ("name",) + + +class ProfileView(GenericAPIView): + permission_classes = [IsAuthenticated] + serializer_class = ProfileSerializer + + def get(self, request, *args, **kwargs): + user = self.request.user if hasattr(self.request, "user") else None + + return Response( + { + "email": user.email, + "password": "**********", + "first_name": user.first_name, + "last_name": user.last_name, + }, + ) + + def patch(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + user = self.request.user if hasattr(self.request, "user") else None + + if len(serializer.validated_data.get("password").replace("*", "")) > 0: + user.set_password(serializer.validated_data.get("password")) + user.save() + serializer.validated_data.pop("password") + + for key, value in serializer.validated_data.items(): + setattr(user, key, value) + user.save() + return Response( + { + "email": user.email, + "password": "**********", + "first_name": user.first_name, + "last_name": user.last_name, + }, + )