diff --git a/src/documents/tests/test_api_profile.py b/src/documents/tests/test_api_profile.py index 1a47ca0c9..713c6829f 100644 --- a/src/documents/tests/test_api_profile.py +++ b/src/documents/tests/test_api_profile.py @@ -1,3 +1,5 @@ +from unittest import mock + from allauth.socialaccount.models import SocialAccount from allauth.socialaccount.models import SocialApp from django.contrib.auth.models import User @@ -137,6 +139,49 @@ class TestApiProfile(DirectoriesMixin, APITestCase): response.data[0]["login_url"], ) + @mock.patch( + "allauth.socialaccount.adapter.DefaultSocialAccountAdapter.list_providers", + ) + def test_get_social_account_providers_openid( + self, + mock_list_providers, + ): + """ + GIVEN: + - Configured user and openid social account provider + WHEN: + - API call is made to get social account providers + THEN: + - Brands for openid provider are returned + """ + + # see allauth.socialaccount.providers.openid.provider.OpenIDProvider + class MockOpenIDProvider: + id = "openid" + name = "OpenID" + + def get_brands(self): + default_servers = [ + dict(id="yahoo", name="Yahoo", openid_url="http://me.yahoo.com"), + dict(id="hyves", name="Hyves", openid_url="http://hyves.nl"), + ] + return default_servers + + def get_login_url(self, request, **kwargs): + return "openid/login/" + + mock_list_providers.return_value = [ + MockOpenIDProvider(), + ] + + response = self.client.get(f"{self.ENDPOINT}social_account_providers/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual( + len(response.data), + 2, + ) + def test_disconnect_social_account(self): """ GIVEN: diff --git a/src/paperless/views.py b/src/paperless/views.py index cf23f6181..1151ceed5 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -209,17 +209,14 @@ class SocialAccountProvidersView(APIView): if p.id != "openid" ] - if ( - openid_provider := next(filter(lambda p: p.id == "openid", providers), None) - is not None - ): + for openid_provider in filter(lambda p: p.id == "openid", providers): resp += [ { - "name": b.name, + "name": b["name"], "login_url": openid_provider.get_login_url( request, process="connect", - openid=b.openid_url, + openid=b["openid_url"], ), } for b in openid_provider.get_brands()