diff --git a/src/paperless/tests/test_adapter.py b/src/paperless/tests/test_adapter.py index c1da288ba..f07e0b422 100644 --- a/src/paperless/tests/test_adapter.py +++ b/src/paperless/tests/test_adapter.py @@ -1,11 +1,12 @@ -from contextvars import ContextVar from unittest import mock from allauth.account.adapter import get_adapter +from allauth.core import context from allauth.socialaccount.adapter import get_adapter as get_social_adapter from django.conf import settings from django.http import HttpRequest from django.test import TestCase +from django.test import override_settings from django.urls import reverse @@ -21,31 +22,30 @@ class TestCustomAccountAdapter(TestCase): settings.ACCOUNT_ALLOW_SIGNUPS = False self.assertFalse(adapter.is_open_for_signup(None)) - @mock.patch("allauth.core.context._request_var") - def test_is_safe_url(self, mock_request_var): + def test_is_safe_url(self): request = HttpRequest() request.get_host = mock.Mock(return_value="example.com") - mock_request_var.return_value = ContextVar("request", default=request) - adapter = get_adapter() + with context.request_context(request): + adapter = get_adapter() + with override_settings(ALLOWED_HOSTS=["*"]): - settings.ALLOWED_HOSTS = ["*"] - # True because request host is same - url = "https://example.com" - self.assertTrue(adapter.is_safe_url(url)) + # True because request host is same + url = "https://example.com" + self.assertTrue(adapter.is_safe_url(url)) - url = "https://evil.com" - # False despite wildcard because request host is different - self.assertFalse(adapter.is_safe_url(url)) + url = "https://evil.com" + # False despite wildcard because request host is different + self.assertFalse(adapter.is_safe_url(url)) - settings.ALLOWED_HOSTS = ["example.com"] - url = "https://example.com" - # True because request host is same - self.assertTrue(adapter.is_safe_url(url)) + settings.ALLOWED_HOSTS = ["example.com"] + url = "https://example.com" + # True because request host is same + self.assertTrue(adapter.is_safe_url(url)) - settings.ALLOWED_HOSTS = ["*", "example.com"] - url = "//evil.com" - # False because request host is not in allowed hosts - self.assertFalse(adapter.is_safe_url(url)) + settings.ALLOWED_HOSTS = ["*", "example.com"] + url = "//evil.com" + # False because request host is not in allowed hosts + self.assertFalse(adapter.is_safe_url(url)) class TestCustomSocialAccountAdapter(TestCase):