Add profile API view

This commit is contained in:
shamoon 2023-11-20 22:59:11 -08:00
parent 5a20c8e512
commit 47916c8b0b
5 changed files with 121 additions and 0 deletions

View File

@ -21,6 +21,7 @@ The API provides the following main endpoints:
- `/api/groups/`: Full CRUD support. - `/api/groups/`: Full CRUD support.
- `/api/share_links/`: Full CRUD support. - `/api/share_links/`: Full CRUD support.
- `/api/custom_fields/`: Full CRUD support. - `/api/custom_fields/`: Full CRUD support.
- `/api/profile/`: GET, PATCH
All of these endpoints except for the logging endpoint allow you to All of these endpoints except for the logging endpoint allow you to
fetch (and edit and delete where appropriate) individual objects by fetch (and edit and delete where appropriate) individual objects by

View File

@ -5793,3 +5793,62 @@ class TestApiConsumptionTemplates(DirectoriesMixin, APITestCase):
self.assertEqual(ConsumptionTemplate.objects.count(), 2) self.assertEqual(ConsumptionTemplate.objects.count(), 2)
ct = ConsumptionTemplate.objects.get(name="Template 2") ct = ConsumptionTemplate.objects.get(name="Template 2")
self.assertEqual(ct.sources, [int(DocumentSource.MailFetch).__str__()]) self.assertEqual(ct.sources, [int(DocumentSource.MailFetch).__str__()])
class TestApiProfile(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/profile/"
def setUp(self):
super().setUp()
self.user = User.objects.create_superuser(
username="temp_admin",
first_name="firstname",
last_name="surname",
)
self.client.force_authenticate(user=self.user)
def test_get_profile(self):
"""
GIVEN:
- Configured user
WHEN:
- API call is made to get profile
THEN:
- Profile is returned
"""
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["email"], self.user.email)
self.assertEqual(response.data["password"], "**********")
self.assertEqual(response.data["first_name"], self.user.first_name)
self.assertEqual(response.data["last_name"], self.user.last_name)
def test_update_profile(self):
"""
GIVEN:
- Configured user
WHEN:
- API call is made to update profile
THEN:
- Profile is updated
"""
user_data = {
"email": "new@email.com",
"password": "superpassword1234",
"first_name": "new first name",
"last_name": "new last name",
}
response = self.client.patch(self.ENDPOINT, user_data)
self.assertEqual(response.status_code, status.HTTP_200_OK)
user = User.objects.get(username=self.user.username)
self.assertTrue(user.check_password(user_data["password"]))
self.assertEqual(user.email, user_data["email"])
self.assertEqual(user.first_name, user_data["first_name"])
self.assertEqual(user.last_name, user_data["last_name"])

View File

@ -97,3 +97,17 @@ class GroupSerializer(serializers.ModelSerializer):
"name", "name",
"permissions", "permissions",
) )
class ProfileSerializer(serializers.ModelSerializer):
email = serializers.EmailField(allow_null=False)
password = ObfuscatedUserPasswordField(required=False)
class Meta:
model = User
fields = (
"email",
"password",
"first_name",
"last_name",
)

View File

@ -36,6 +36,7 @@ from documents.views import UnifiedSearchViewSet
from paperless.consumers import StatusConsumer from paperless.consumers import StatusConsumer
from paperless.views import FaviconView from paperless.views import FaviconView
from paperless.views import GroupViewSet from paperless.views import GroupViewSet
from paperless.views import ProfileView
from paperless.views import UserViewSet from paperless.views import UserViewSet
from paperless_mail.views import MailAccountTestView from paperless_mail.views import MailAccountTestView
from paperless_mail.views import MailAccountViewSet from paperless_mail.views import MailAccountViewSet
@ -119,6 +120,11 @@ urlpatterns = [
BulkEditObjectPermissionsView.as_view(), BulkEditObjectPermissionsView.as_view(),
name="bulk_edit_object_permissions", name="bulk_edit_object_permissions",
), ),
re_path(
"^profile/",
ProfileView.as_view(),
name="profile_view",
),
*api_router.urls, *api_router.urls,
], ],
), ),

View File

@ -8,6 +8,7 @@ from django.http import HttpResponse
from django.views.generic import View from django.views.generic import View
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.filters import OrderingFilter from rest_framework.filters import OrderingFilter
from rest_framework.generics import GenericAPIView
from rest_framework.pagination import PageNumberPagination from rest_framework.pagination import PageNumberPagination
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
@ -17,6 +18,7 @@ from documents.permissions import PaperlessObjectPermissions
from paperless.filters import GroupFilterSet from paperless.filters import GroupFilterSet
from paperless.filters import UserFilterSet from paperless.filters import UserFilterSet
from paperless.serialisers import GroupSerializer from paperless.serialisers import GroupSerializer
from paperless.serialisers import ProfileSerializer
from paperless.serialisers import UserSerializer from paperless.serialisers import UserSerializer
@ -106,3 +108,42 @@ class GroupViewSet(ModelViewSet):
filter_backends = (DjangoFilterBackend, OrderingFilter) filter_backends = (DjangoFilterBackend, OrderingFilter)
filterset_class = GroupFilterSet filterset_class = GroupFilterSet
ordering_fields = ("name",) ordering_fields = ("name",)
class ProfileView(GenericAPIView):
permission_classes = [IsAuthenticated]
serializer_class = ProfileSerializer
def get(self, request, *args, **kwargs):
user = self.request.user if hasattr(self.request, "user") else None
return Response(
{
"email": user.email,
"password": "**********",
"first_name": user.first_name,
"last_name": user.last_name,
},
)
def patch(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
user = self.request.user if hasattr(self.request, "user") else None
if len(serializer.validated_data.get("password").replace("*", "")) > 0:
user.set_password(serializer.validated_data.get("password"))
user.save()
serializer.validated_data.pop("password")
for key, value in serializer.validated_data.items():
setattr(user, key, value)
user.save()
return Response(
{
"email": user.email,
"password": "**********",
"first_name": user.first_name,
"last_name": user.last_name,
},
)