From 855543377c6500b4213e9a69ab1300eb631c258d Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Mon, 7 Oct 2024 10:58:47 -0700 Subject: [PATCH] Migrate to httpx_oauth --- Pipfile | 1 + Pipfile.lock | 13 +- src/documents/views.py | 16 +- src/paperless_mail/mail.py | 6 +- src/paperless_mail/oauth.py | 230 ++++++++++---------- src/paperless_mail/tests/test_mail.py | 41 ++-- src/paperless_mail/tests/test_mail_oauth.py | 160 ++++++-------- src/paperless_mail/views.py | 88 ++++---- 8 files changed, 279 insertions(+), 276 deletions(-) diff --git a/Pipfile b/Pipfile index c2db33487..7c21379fd 100644 --- a/Pipfile +++ b/Pipfile @@ -58,6 +58,7 @@ whitenoise = "~=6.7" whoosh = "~=2.7" zxing-cpp = {version = "*", platform_machine = "== 'x86_64'"} jinja2 = "~=3.1" +httpx-oauth = "*" [dev-packages] # Linting diff --git a/Pipfile.lock b/Pipfile.lock index 675e89c10..8ed6a6861 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "1e113d0879e4e0bc3c384115057647ac8d9be05252dd7c708a1fc873f294ef28" + "sha256": "584249cbeaf29659c975000b5e02b12e45d768d795e4a8ac36118e73bd7c0b8a" }, "pipfile-spec": 6, "requires": {}, @@ -799,9 +799,18 @@ "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0", "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2" ], - "markers": "python_version >= '3.9'", + "markers": "python_version >= '3.8'", "version": "==0.27.2" }, + "httpx-oauth": { + "hashes": [ + "sha256:4094cf0938fc7252b5f5dfd62cd1ab5aee2fcb6734e621942ee17d1af4806b74", + "sha256:89b45f250e93e42bbe9631adf349cab0e3d3ced958c07e06651735198d1bdf00" + ], + "index": "pypi", + "markers": "python_version >= '3.8'", + "version": "==0.15.1" + }, "humanize": { "hashes": [ "sha256:06b6eb0293e4b85e8d385397c5868926820db32b9b654b932f57fa41c23c9978", diff --git a/src/documents/views.py b/src/documents/views.py index fea33aa4a..62bffcfac 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -160,8 +160,7 @@ from paperless.serialisers import UserSerializer from paperless.views import StandardPagination from paperless_mail.models import MailAccount from paperless_mail.models import MailRule -from paperless_mail.oauth import generate_gmail_oauth_url -from paperless_mail.oauth import generate_outlook_oauth_url +from paperless_mail.oauth import PaperlessMailOAuth2Manager from paperless_mail.serialisers import MailAccountSerializer from paperless_mail.serialisers import MailRuleSerializer @@ -1586,11 +1585,14 @@ class UiSettingsView(GenericAPIView): ui_settings["auditlog_enabled"] = settings.AUDIT_LOG_ENABLED - if settings.GMAIL_OAUTH_ENABLED: - ui_settings["gmail_oauth_url"] = generate_gmail_oauth_url() - - if settings.OUTLOOK_OAUTH_ENABLED: - ui_settings["outlook_oauth_url"] = generate_outlook_oauth_url() + if settings.GMAIL_OAUTH_ENABLED or settings.OUTLOOK_OAUTH_ENABLED: + manager = PaperlessMailOAuth2Manager() + if settings.GMAIL_OAUTH_ENABLED: + ui_settings["gmail_oauth_url"] = manager.get_gmail_authorization_url() + if settings.OUTLOOK_OAUTH_ENABLED: + ui_settings["outlook_oauth_url"] = ( + manager.get_outlook_authorization_url() + ) user_resp = { "id": user.id, diff --git a/src/paperless_mail/mail.py b/src/paperless_mail/mail.py index a20cf9c9c..5c1c13a41 100644 --- a/src/paperless_mail/mail.py +++ b/src/paperless_mail/mail.py @@ -1,3 +1,4 @@ +import asyncio import datetime import itertools import logging @@ -43,7 +44,7 @@ from documents.tasks import consume_file from paperless_mail.models import MailAccount from paperless_mail.models import MailRule from paperless_mail.models import ProcessedMail -from paperless_mail.oauth import refresh_oauth_token +from paperless_mail.oauth import PaperlessMailOAuth2Manager from paperless_mail.preprocessor import MailMessageDecryptor from paperless_mail.preprocessor import MailMessagePreprocessor @@ -537,7 +538,8 @@ class MailAccountHandler(LoggingMixin): and account.expiration is not None and account.expiration < timezone.now() ): - if refresh_oauth_token(account): + manager = PaperlessMailOAuth2Manager() + if asyncio.run(manager.refresh_account_oauth_token(account)): account.refresh_from_db() else: return total_processed_files diff --git a/src/paperless_mail/oauth.py b/src/paperless_mail/oauth.py index b15f2c30d..67fbcab9e 100644 --- a/src/paperless_mail/oauth.py +++ b/src/paperless_mail/oauth.py @@ -1,122 +1,132 @@ +import asyncio import logging from datetime import timedelta -import httpx from django.conf import settings from django.utils import timezone +from httpx_oauth.oauth2 import OAuth2 +from httpx_oauth.oauth2 import OAuth2Token +from httpx_oauth.oauth2 import RefreshTokenError from paperless_mail.models import MailAccount -GMAIL_OAUTH_ENDPOINT_TOKEN = "https://accounts.google.com/o/oauth2/token" -GMAIL_OAUTH_ENDPOINT_AUTH = "https://accounts.google.com/o/oauth2/auth" -OUTLOOK_OAUTH_ENDPOINT_TOKEN = ( - "https://login.microsoftonline.com/common/oauth2/v2.0/token" -) -OUTLOOK_OAUTH_ENDPOINT_AUTH = ( - "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" -) - -def get_oauth_callback_url() -> str: - return f"{settings.OAUTH_CALLBACK_BASE_URL if settings.OAUTH_CALLBACK_BASE_URL is not None else settings.PAPERLESS_URL}{settings.BASE_URL}api/oauth/callback/" - - -def get_oauth_redirect_url() -> str: - return f"{'http://localhost:4200/' if settings.DEBUG else settings.BASE_URL}mail" # e.g. "http://localhost:4200/mail" or "/mail" - - -def generate_gmail_oauth_url() -> str: - response_type = "code" - client_id = settings.GMAIL_OAUTH_CLIENT_ID - redirect_uri = get_oauth_callback_url() - scope = "https://mail.google.com/" - access_type = "offline" - url = f"{GMAIL_OAUTH_ENDPOINT_AUTH}?response_type={response_type}&client_id={client_id}&redirect_uri={redirect_uri}&scope={scope}&access_type={access_type}&prompt=consent" - return url - - -def generate_outlook_oauth_url() -> str: - response_type = "code" - client_id = settings.OUTLOOK_OAUTH_CLIENT_ID - redirect_uri = get_oauth_callback_url() - scope = "offline_access https://outlook.office.com/IMAP.AccessAsUser.All" - url = f"{OUTLOOK_OAUTH_ENDPOINT_AUTH}?response_type={response_type}&response_mode=query&client_id={client_id}&redirect_uri={redirect_uri}&scope={scope}" - return url - - -def generate_gmail_oauth_token_request_data(code: str) -> dict: - client_id = settings.GMAIL_OAUTH_CLIENT_ID - client_secret = settings.GMAIL_OAUTH_CLIENT_SECRET - scope = "https://mail.google.com/" - - return { - "code": code, - "client_id": client_id, - "client_secret": client_secret, - "scope": scope, - "redirect_uri": get_oauth_callback_url(), - "grant_type": "authorization_code", - } - - -def generate_outlook_oauth_token_request_data(code: str) -> dict: - client_id = settings.OUTLOOK_OAUTH_CLIENT_ID - client_secret = settings.OUTLOOK_OAUTH_CLIENT_SECRET - scope = "offline_access https://outlook.office.com/IMAP.AccessAsUser.All" - - return { - "code": code, - "client_id": client_id, - "client_secret": client_secret, - "scope": scope, - "redirect_uri": get_oauth_callback_url(), - "grant_type": "authorization_code", - } - - -def refresh_oauth_token(account: MailAccount) -> bool: - """ - Refreshes the oauth token for the given mail account. - """ - logger = logging.getLogger("paperless_mail") - logger.debug(f"Attempting to refresh oauth token for account {account}") - if not account.refresh_token: - logger.error(f"Account {account}: No refresh token available.") - return False - - if account.account_type == MailAccount.MailAccountType.GMAIL_OAUTH: - url = GMAIL_OAUTH_ENDPOINT_TOKEN - data = { - "client_id": settings.GMAIL_OAUTH_CLIENT_ID, - "client_secret": settings.GMAIL_OAUTH_CLIENT_SECRET, - "refresh_token": account.refresh_token, - "grant_type": "refresh_token", - } - elif account.account_type == MailAccount.MailAccountType.OUTLOOK_OAUTH: - url = OUTLOOK_OAUTH_ENDPOINT_TOKEN - data = { - "client_id": settings.OUTLOOK_OAUTH_CLIENT_ID, - "client_secret": settings.OUTLOOK_OAUTH_CLIENT_SECRET, - "refresh_token": account.refresh_token, - "grant_type": "refresh_token", - } - - response = httpx.post( - url=url, - data=data, - headers={"Content-Type": "application/x-www-form-urlencoded"}, +class PaperlessMailOAuth2Manager: + GMAIL_OAUTH_ENDPOINT_TOKEN = "https://accounts.google.com/o/oauth2/token" + GMAIL_OAUTH_ENDPOINT_AUTH = "https://accounts.google.com/o/oauth2/auth" + OUTLOOK_OAUTH_ENDPOINT_TOKEN = ( + "https://login.microsoftonline.com/common/oauth2/v2.0/token" ) - data = response.json() - if response.status_code < 400 and "access_token" in data: - account.password = data["access_token"] - account.expiration = timezone.now() + timedelta( - seconds=data["expires_in"], + OUTLOOK_OAUTH_ENDPOINT_AUTH = ( + "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + ) + + def __init__(self): + self._gmail_client = None + self._outlook_client = None + + @property + def gmail_client(self) -> OAuth2: + if self._gmail_client is None: + self._gmail_client = OAuth2( + settings.GMAIL_OAUTH_CLIENT_ID, + settings.GMAIL_OAUTH_CLIENT_SECRET, + self.GMAIL_OAUTH_ENDPOINT_AUTH, + self.GMAIL_OAUTH_ENDPOINT_TOKEN, + refresh_token_endpoint=self.GMAIL_OAUTH_ENDPOINT_TOKEN, + token_endpoint_auth_method="client_secret_post", + ) + return self._gmail_client + + @property + def outlook_client(self) -> OAuth2: + if self._outlook_client is None: + self._outlook_client = OAuth2( + settings.OUTLOOK_OAUTH_CLIENT_ID, + settings.OUTLOOK_OAUTH_CLIENT_SECRET, + self.OUTLOOK_OAUTH_ENDPOINT_AUTH, + self.OUTLOOK_OAUTH_ENDPOINT_TOKEN, + refresh_token_endpoint=self.OUTLOOK_OAUTH_ENDPOINT_TOKEN, + token_endpoint_auth_method="client_secret_post", + ) + return self._outlook_client + + @property + def oauth_callback_url(self) -> str: + return f"{settings.OAUTH_CALLBACK_BASE_URL if settings.OAUTH_CALLBACK_BASE_URL is not None else settings.PAPERLESS_URL}{settings.BASE_URL}api/oauth/callback/" + + @property + def oauth_redirect_url(self) -> str: + return f"{'http://localhost:4200/' if settings.DEBUG else settings.BASE_URL}mail" # e.g. "http://localhost:4200/mail" or "/mail" + + def get_gmail_authorization_url(self) -> str: + return asyncio.run( + self.gmail_client.get_authorization_url( + redirect_uri=self.oauth_callback_url, + scope=["https://mail.google.com/"], + extras_params={"prompt": "consent", "access_type": "offline"}, + ), ) - account.save() - logger.debug(f"Successfully refreshed oauth token for account {account}") - return True - else: - logger.error( - f"Failed to refresh oauth token for account {account}: {response}", + + def get_outlook_authorization_url(self) -> str: + return asyncio.run( + self.outlook_client.get_authorization_url( + redirect_uri=self.oauth_callback_url, + scope=[ + "offline_access", + "https://outlook.office.com/IMAP.AccessAsUser.All", + ], + extras_params={"response_type": "code"}, + ), ) - return False + + def get_gmail_access_token(self, code: str) -> OAuth2Token: + return asyncio.run( + self.gmail_client.get_access_token( + code=code, + redirect_uri=self.oauth_callback_url, + ), + ) + + def get_outlook_access_token(self, code: str) -> OAuth2Token: + return asyncio.run( + self.outlook_client.get_access_token( + code=code, + redirect_uri=self.oauth_callback_url, + ), + ) + + def refresh_account_oauth_token(self, account: MailAccount) -> bool: + """ + Refreshes the oauth token for the given mail account. + """ + logger = logging.getLogger("paperless_mail") + logger.debug(f"Attempting to refresh oauth token for account {account}") + if not account.refresh_token: + logger.error(f"Account {account}: No refresh token available.") + return False + + try: + result: OAuth2Token + if account.account_type == MailAccount.MailAccountType.GMAIL_OAUTH: + result = asyncio.run( + self.gmail_client.refresh_token( + refresh_token=account.refresh_token, + ), + ) + elif account.account_type == MailAccount.MailAccountType.OUTLOOK_OAUTH: + result = asyncio.run( + self.outlook_client.refresh_token( + refresh_token=account.refresh_token, + ), + ) + account.password = result["access_token"] + account.expiration = timezone.now() + timedelta( + seconds=result["expires_in"], + ) + account.save() + logger.debug(f"Successfully refreshed oauth token for account {account}") + return True + except RefreshTokenError as e: + logger.error(f"Failed to refresh oauth token for account {account}: {e}") + return False diff --git a/src/paperless_mail/tests/test_mail.py b/src/paperless_mail/tests/test_mail.py index c2838f8f4..c6d0f04bb 100644 --- a/src/paperless_mail/tests/test_mail.py +++ b/src/paperless_mail/tests/test_mail.py @@ -35,7 +35,6 @@ from paperless_mail.mail import apply_mail_action from paperless_mail.models import MailAccount from paperless_mail.models import MailRule from paperless_mail.models import ProcessedMail -from paperless_mail.oauth import GMAIL_OAUTH_ENDPOINT_TOKEN @dataclasses.dataclass @@ -1636,11 +1635,19 @@ class TestMailAccountTestView(APITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.content.decode(), "Unable to connect to server") - @mock.patch("httpx.post") + @mock.patch("paperless_mail.oauth.PaperlessMailOAuth2Manager") def test_mail_account_test_view_refresh_token( self, - mock_post, + mock_manager, ): + """ + GIVEN: + - Mail account with expired token + WHEN: + - Mail account is tested + THEN: + - Should refresh the token + """ existing_account = MailAccount.objects.create( imap_server="imap.example.com", imap_port=993, @@ -1653,11 +1660,7 @@ class TestMailAccountTestView(APITestCase): is_token=True, ) - mock_post.return_value.status_code = status.HTTP_200_OK - mock_post.return_value.json.return_value = { - "access_token": "newtoken", - "expires_in": 3600, - } + mock_manager.return_value.refresh_account_oauth_token.return_value = True data = { "id": existing_account.id, "imap_server": "imap.example.com", @@ -1668,13 +1671,22 @@ class TestMailAccountTestView(APITestCase): "is_token": True, } self.client.post(self.url, data, format="json") - self.assertEqual(mock_post.call_args[1]["url"], GMAIL_OAUTH_ENDPOINT_TOKEN) + self.assertEqual(mock_manager.call_count, 1) - @mock.patch("httpx.post") + @mock.patch("paperless_mail.oauth.PaperlessMailOAuth2Manager") def test_mail_account_test_view_refresh_token_fails( self, - mock_post, + mock_manager, ): + """ + GIVEN: + - Mail account with expired token + WHEN: + - Mail account is tested + - Token refresh fails + THEN: + - Should log an error + """ existing_account = MailAccount.objects.create( imap_server="imap.example.com", imap_port=993, @@ -1687,10 +1699,7 @@ class TestMailAccountTestView(APITestCase): is_token=True, ) - mock_post.return_value.status_code = status.HTTP_400_BAD_REQUEST - mock_post.return_value.json.return_value = { - "error": "invalid_grant", - } + mock_manager.return_value.refresh_account_oauth_token.return_value = False data = { "id": existing_account.id, "imap_server": "imap.example.com", @@ -1704,5 +1713,5 @@ class TestMailAccountTestView(APITestCase): response = self.client.post(self.url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) error_str = cm.output[0] - expected_str = "Failed to refresh oauth token for account" + expected_str = "Unable to refresh oauth token" self.assertIn(expected_str, error_str) diff --git a/src/paperless_mail/tests/test_mail_oauth.py b/src/paperless_mail/tests/test_mail_oauth.py index a456316f2..e269b74ac 100644 --- a/src/paperless_mail/tests/test_mail_oauth.py +++ b/src/paperless_mail/tests/test_mail_oauth.py @@ -6,15 +6,12 @@ from django.contrib.auth.models import User from django.test import TestCase from django.test import override_settings from django.utils import timezone +from httpx_oauth.oauth2 import GetAccessTokenError from rest_framework import status from paperless_mail.mail import MailAccountHandler from paperless_mail.models import MailAccount -from paperless_mail.oauth import generate_gmail_oauth_url -from paperless_mail.oauth import generate_outlook_oauth_url -from paperless_mail.oauth import get_oauth_callback_url -from paperless_mail.oauth import get_oauth_redirect_url -from paperless_mail.oauth import refresh_oauth_token +from paperless_mail.oauth import PaperlessMailOAuth2Manager class TestMailOAuth( @@ -42,9 +39,10 @@ class TestMailOAuth( - Correct URLs are generated """ # Callback URL + oauth_manager = PaperlessMailOAuth2Manager() with override_settings(OAUTH_CALLBACK_BASE_URL="http://paperless.example.com"): self.assertEqual( - get_oauth_callback_url(), + oauth_manager.oauth_callback_url, "http://paperless.example.com/api/oauth/callback/", ) with override_settings( @@ -52,7 +50,7 @@ class TestMailOAuth( PAPERLESS_URL="http://paperless.example.com", ): self.assertEqual( - get_oauth_callback_url(), + oauth_manager.oauth_callback_url, "http://paperless.example.com/api/oauth/callback/", ) with override_settings( @@ -61,42 +59,33 @@ class TestMailOAuth( BASE_URL="/paperless/", ): self.assertEqual( - get_oauth_callback_url(), + oauth_manager.oauth_callback_url, "http://paperless.example.com/paperless/api/oauth/callback/", ) # Redirect URL with override_settings(DEBUG=True): self.assertEqual( - get_oauth_redirect_url(), + oauth_manager.oauth_redirect_url, "http://localhost:4200/mail", ) with override_settings(DEBUG=False): self.assertEqual( - get_oauth_redirect_url(), + oauth_manager.oauth_redirect_url, "/mail", ) - def test_generate_oauth_urls(self): - """ - GIVEN: - - Mocked settings for Gmail and Outlook OAuth client IDs - WHEN: - - generate_gmail_oauth_url and generate_outlook_oauth_url are called - THEN: - - Correct URLs are generated - """ - self.assertEqual( - "https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=test_gmail_client_id&redirect_uri=http://localhost:8000/api/oauth/callback/&scope=https://mail.google.com/&access_type=offline&prompt=consent", - generate_gmail_oauth_url(), - ) - self.assertEqual( - "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?response_type=code&response_mode=query&client_id=test_outlook_client_id&redirect_uri=http://localhost:8000/api/oauth/callback/&scope=offline_access https://outlook.office.com/IMAP.AccessAsUser.All", - generate_outlook_oauth_url(), - ) - - @mock.patch("httpx.post") - def test_oauth_callback_view(self, mock_post): + @mock.patch( + "paperless_mail.oauth.PaperlessMailOAuth2Manager.get_gmail_access_token", + ) + @mock.patch( + "paperless_mail.oauth.PaperlessMailOAuth2Manager.get_outlook_access_token", + ) + def test_oauth_callback_view( + self, + mock_get_outlook_access_token, + mock_get_gmail_access_token, + ): """ GIVEN: - Mocked settings for Gmail and Outlook OAuth client IDs and secrets @@ -106,7 +95,12 @@ class TestMailOAuth( - Gmail mail account is created """ - mock_post.return_value.json.return_value = { + mock_get_gmail_access_token.return_value = { + "access_token": "test_access_token", + "refresh_token": "test_refresh_token", + "expires_in": 3600, + } + mock_get_outlook_access_token.return_value = { "access_token": "test_access_token", "refresh_token": "test_refresh_token", "expires_in": 3600, @@ -118,7 +112,7 @@ class TestMailOAuth( ) self.assertEqual(response.status_code, status.HTTP_302_FOUND) self.assertIn("oauth_success=1", response.url) - mock_post.assert_called_once() + mock_get_gmail_access_token.assert_called_once() self.assertTrue( MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), ) @@ -152,37 +146,51 @@ class TestMailOAuth( MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), ) - @mock.patch("httpx.post") - def test_oauth_callback_view_error(self, mock_post): + @mock.patch("httpx_oauth.oauth2.BaseOAuth2.get_access_token") + def test_oauth_callback_view_error(self, mock_get_access_token): """ GIVEN: - Mocked settings for Gmail and Outlook OAuth client IDs and secrets WHEN: - OAuth callback is called with an error THEN: + - No mail account is created - Error is logged """ + mock_get_access_token.side_effect = GetAccessTokenError("test_error") - mock_post.return_value.json.return_value = { - "error": "test_error", - } + with self.assertLogs("paperless_mail", level="ERROR") as cm: + # Test Google OAuth callback + response = self.client.get( + "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/", + ) + self.assertEqual(response.status_code, status.HTTP_302_FOUND) + self.assertIn("oauth_success=0", response.url) + self.assertFalse( + MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), + ) - response = self.client.get( - "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/", - ) - self.assertEqual(response.status_code, status.HTTP_302_FOUND) - self.assertIn("oauth_success=0", response.url) - mock_post.assert_called_once() - self.assertFalse( - MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), - ) - self.assertFalse( - MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), - ) + # Test Outlook OAuth callback + response = self.client.get("/api/oauth/callback/?code=test_code") + self.assertEqual(response.status_code, status.HTTP_302_FOUND) + self.assertIn("oauth_success=0", response.url) + self.assertFalse( + MailAccount.objects.filter( + imap_server="outlook.office365.com", + ).exists(), + ) + + self.assertIn("Error getting access token: test_error", cm.output[0]) @mock.patch("paperless_mail.mail.get_mailbox") - @mock.patch("httpx.post") - def test_refresh_token_on_handle_mail_account(self, mock_post, mock_get_mailbox): + @mock.patch( + "paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token", + ) + def test_refresh_token_on_handle_mail_account( + self, + mock_refresh_account_oauth_token, + mock_get_mailbox, + ): """ GIVEN: - Mail account with refresh token and expiration @@ -206,16 +214,13 @@ class TestMailOAuth( expiration=timezone.now() - timedelta(days=1), ) - mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = { - "access_token": "test_access_token", - "refresh_token": "test_refresh_token", - "expires_in": 3600, - } + mock_refresh_account_oauth_token.return_value = True self.mail_account_handler.handle_mail_account(mail_account) - mock_post.assert_called_once() + mock_refresh_account_oauth_token.assert_called_once() + mock_refresh_account_oauth_token.reset_mock() + mock_refresh_account_oauth_token.return_value = True outlook_mail_account = MailAccount.objects.create( name="Test Outlook Mail Account", username="test_username", @@ -228,13 +233,15 @@ class TestMailOAuth( ) self.mail_account_handler.handle_mail_account(outlook_mail_account) - mock_post.assert_called() + mock_refresh_account_oauth_token.assert_called_once() @mock.patch("paperless_mail.mail.get_mailbox") - @mock.patch("httpx.post") + @mock.patch( + "paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token", + ) def test_refresh_token_on_handle_mail_account_fails( self, - mock_post, + mock_refresh_account_oauth_token, mock_get_mailbox, ): """ @@ -261,37 +268,10 @@ class TestMailOAuth( expiration=timezone.now() - timedelta(days=1), ) - mock_post.return_value.status_code = 400 - mock_post.return_value.json.return_value = { - "error": "test_error", - } + mock_refresh_account_oauth_token.return_value = False self.assertEqual( self.mail_account_handler.handle_mail_account(mail_account), 0, ) - mock_post.assert_called_once() - - def test_refresh_token_invalid_account(self): - """ - GIVEN: - - Mail account without refresh token - WHEN: - - refresh_oauth_token is called - THEN: - - False is returned - """ - - mail_account = MailAccount.objects.create( - name="test_mail_account", - username="test_username", - imap_security=MailAccount.ImapSecurity.SSL, - imap_port=993, - account_type=MailAccount.MailAccountType.GMAIL_OAUTH, - is_token=True, - expiration=timezone.now() - timedelta(days=1), - ) - - self.assertFalse( - refresh_oauth_token(mail_account), - ) + mock_refresh_account_oauth_token.assert_called_once() diff --git a/src/paperless_mail/views.py b/src/paperless_mail/views.py index 7f9c82eb3..718dc4eec 100644 --- a/src/paperless_mail/views.py +++ b/src/paperless_mail/views.py @@ -2,10 +2,10 @@ import datetime import logging from datetime import timedelta -import httpx from django.http import HttpResponseBadRequest from django.http import HttpResponseRedirect from django.utils import timezone +from httpx_oauth.oauth2 import GetAccessTokenError from rest_framework.generics import GenericAPIView from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response @@ -20,12 +20,7 @@ from paperless_mail.mail import get_mailbox from paperless_mail.mail import mailbox_login from paperless_mail.models import MailAccount from paperless_mail.models import MailRule -from paperless_mail.oauth import GMAIL_OAUTH_ENDPOINT_TOKEN -from paperless_mail.oauth import OUTLOOK_OAUTH_ENDPOINT_TOKEN -from paperless_mail.oauth import generate_gmail_oauth_token_request_data -from paperless_mail.oauth import generate_outlook_oauth_token_request_data -from paperless_mail.oauth import get_oauth_redirect_url -from paperless_mail.oauth import refresh_oauth_token +from paperless_mail.oauth import PaperlessMailOAuth2Manager from paperless_mail.serialisers import MailAccountSerializer from paperless_mail.serialisers import MailRuleSerializer @@ -83,7 +78,8 @@ class MailAccountTestView(GenericAPIView): and account.expiration is not None and account.expiration < timezone.now() ): - if refresh_oauth_token(existing_account): + oauth_manager = PaperlessMailOAuth2Manager() + if oauth_manager.refresh_account_oauth_token(existing_account): # User is not changing password and token needs to be refreshed existing_account.refresh_from_db() account.password = existing_account.password @@ -114,50 +110,39 @@ class OauthCallbackView(GenericAPIView): ) return HttpResponseBadRequest("Invalid request, see logs for more detail") - if scope is not None and "google" in scope: - # Google - account_type = MailAccount.MailAccountType.GMAIL_OAUTH - imap_server = "imap.gmail.com" - defaults = { - "name": f"Gmail OAuth {timezone.now()}", - "username": "", - "imap_security": MailAccount.ImapSecurity.SSL, - "imap_port": 993, - "account_type": account_type, - } - token_request_uri = GMAIL_OAUTH_ENDPOINT_TOKEN - data = generate_gmail_oauth_token_request_data(code) + oauth_manager = PaperlessMailOAuth2Manager() - elif scope is None: - # Outlook - account_type = MailAccount.MailAccountType.OUTLOOK_OAUTH - imap_server = "outlook.office365.com" - defaults = { - "name": f"Outlook OAuth {timezone.now()}", - "username": "", - "imap_security": MailAccount.ImapSecurity.SSL, - "imap_port": 993, - "account_type": account_type, - } + try: + if scope is not None and "google" in scope: + # Google + account_type = MailAccount.MailAccountType.GMAIL_OAUTH + imap_server = "imap.gmail.com" + defaults = { + "name": f"Gmail OAuth {timezone.now()}", + "username": "", + "imap_security": MailAccount.ImapSecurity.SSL, + "imap_port": 993, + "account_type": account_type, + } + result = oauth_manager.get_gmail_access_token(code) - token_request_uri = OUTLOOK_OAUTH_ENDPOINT_TOKEN - data = generate_outlook_oauth_token_request_data(code) + elif scope is None: + # Outlook + account_type = MailAccount.MailAccountType.OUTLOOK_OAUTH + imap_server = "outlook.office365.com" + defaults = { + "name": f"Outlook OAuth {timezone.now()}", + "username": "", + "imap_security": MailAccount.ImapSecurity.SSL, + "imap_port": 993, + "account_type": account_type, + } - headers = { - "Content-Type": "application/x-www-form-urlencoded", - } - response = httpx.post(token_request_uri, data=data, headers=headers) - data = response.json() + result = oauth_manager.get_outlook_access_token(code) - if "error" in data: - logger.error(f"Error {response.status_code} getting access token: {data}") - return HttpResponseRedirect( - f"{get_oauth_redirect_url()}?oauth_success=0", - ) - elif "access_token" in data: - access_token = data["access_token"] - refresh_token = data["refresh_token"] - expires_in = data["expires_in"] + access_token = result["access_token"] + refresh_token = result["refresh_token"] + expires_in = result["expires_in"] account, _ = MailAccount.objects.update_or_create( password=access_token, is_token=True, @@ -167,5 +152,10 @@ class OauthCallbackView(GenericAPIView): defaults=defaults, ) return HttpResponseRedirect( - f"{get_oauth_redirect_url()}?oauth_success=1&account_id={account.pk}", + f"{oauth_manager.oauth_redirect_url}?oauth_success=1&account_id={account.pk}", + ) + except GetAccessTokenError as e: + logger.error(f"Error getting access token: {e}") + return HttpResponseRedirect( + f"{oauth_manager.oauth_redirect_url}?oauth_success=0", )