Migrate to httpx_oauth

This commit is contained in:
shamoon 2024-10-07 10:58:47 -07:00
parent ad483b49a1
commit 855543377c
8 changed files with 279 additions and 276 deletions

View File

@ -58,6 +58,7 @@ whitenoise = "~=6.7"
whoosh = "~=2.7" whoosh = "~=2.7"
zxing-cpp = {version = "*", platform_machine = "== 'x86_64'"} zxing-cpp = {version = "*", platform_machine = "== 'x86_64'"}
jinja2 = "~=3.1" jinja2 = "~=3.1"
httpx-oauth = "*"
[dev-packages] [dev-packages]
# Linting # Linting

13
Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "1e113d0879e4e0bc3c384115057647ac8d9be05252dd7c708a1fc873f294ef28" "sha256": "584249cbeaf29659c975000b5e02b12e45d768d795e4a8ac36118e73bd7c0b8a"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": {}, "requires": {},
@ -799,9 +799,18 @@
"sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0", "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0",
"sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2" "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"
], ],
"markers": "python_version >= '3.9'", "markers": "python_version >= '3.8'",
"version": "==0.27.2" "version": "==0.27.2"
}, },
"httpx-oauth": {
"hashes": [
"sha256:4094cf0938fc7252b5f5dfd62cd1ab5aee2fcb6734e621942ee17d1af4806b74",
"sha256:89b45f250e93e42bbe9631adf349cab0e3d3ced958c07e06651735198d1bdf00"
],
"index": "pypi",
"markers": "python_version >= '3.8'",
"version": "==0.15.1"
},
"humanize": { "humanize": {
"hashes": [ "hashes": [
"sha256:06b6eb0293e4b85e8d385397c5868926820db32b9b654b932f57fa41c23c9978", "sha256:06b6eb0293e4b85e8d385397c5868926820db32b9b654b932f57fa41c23c9978",

View File

@ -160,8 +160,7 @@ from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination from paperless.views import StandardPagination
from paperless_mail.models import MailAccount from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule from paperless_mail.models import MailRule
from paperless_mail.oauth import generate_gmail_oauth_url from paperless_mail.oauth import PaperlessMailOAuth2Manager
from paperless_mail.oauth import generate_outlook_oauth_url
from paperless_mail.serialisers import MailAccountSerializer from paperless_mail.serialisers import MailAccountSerializer
from paperless_mail.serialisers import MailRuleSerializer from paperless_mail.serialisers import MailRuleSerializer
@ -1586,11 +1585,14 @@ class UiSettingsView(GenericAPIView):
ui_settings["auditlog_enabled"] = settings.AUDIT_LOG_ENABLED ui_settings["auditlog_enabled"] = settings.AUDIT_LOG_ENABLED
if settings.GMAIL_OAUTH_ENABLED or settings.OUTLOOK_OAUTH_ENABLED:
manager = PaperlessMailOAuth2Manager()
if settings.GMAIL_OAUTH_ENABLED: if settings.GMAIL_OAUTH_ENABLED:
ui_settings["gmail_oauth_url"] = generate_gmail_oauth_url() ui_settings["gmail_oauth_url"] = manager.get_gmail_authorization_url()
if settings.OUTLOOK_OAUTH_ENABLED: if settings.OUTLOOK_OAUTH_ENABLED:
ui_settings["outlook_oauth_url"] = generate_outlook_oauth_url() ui_settings["outlook_oauth_url"] = (
manager.get_outlook_authorization_url()
)
user_resp = { user_resp = {
"id": user.id, "id": user.id,

View File

@ -1,3 +1,4 @@
import asyncio
import datetime import datetime
import itertools import itertools
import logging import logging
@ -43,7 +44,7 @@ from documents.tasks import consume_file
from paperless_mail.models import MailAccount from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule from paperless_mail.models import MailRule
from paperless_mail.models import ProcessedMail from paperless_mail.models import ProcessedMail
from paperless_mail.oauth import refresh_oauth_token from paperless_mail.oauth import PaperlessMailOAuth2Manager
from paperless_mail.preprocessor import MailMessageDecryptor from paperless_mail.preprocessor import MailMessageDecryptor
from paperless_mail.preprocessor import MailMessagePreprocessor from paperless_mail.preprocessor import MailMessagePreprocessor
@ -537,7 +538,8 @@ class MailAccountHandler(LoggingMixin):
and account.expiration is not None and account.expiration is not None
and account.expiration < timezone.now() and account.expiration < timezone.now()
): ):
if refresh_oauth_token(account): manager = PaperlessMailOAuth2Manager()
if asyncio.run(manager.refresh_account_oauth_token(account)):
account.refresh_from_db() account.refresh_from_db()
else: else:
return total_processed_files return total_processed_files

View File

@ -1,80 +1,102 @@
import asyncio
import logging import logging
from datetime import timedelta from datetime import timedelta
import httpx
from django.conf import settings from django.conf import settings
from django.utils import timezone from django.utils import timezone
from httpx_oauth.oauth2 import OAuth2
from httpx_oauth.oauth2 import OAuth2Token
from httpx_oauth.oauth2 import RefreshTokenError
from paperless_mail.models import MailAccount from paperless_mail.models import MailAccount
GMAIL_OAUTH_ENDPOINT_TOKEN = "https://accounts.google.com/o/oauth2/token"
GMAIL_OAUTH_ENDPOINT_AUTH = "https://accounts.google.com/o/oauth2/auth" class PaperlessMailOAuth2Manager:
OUTLOOK_OAUTH_ENDPOINT_TOKEN = ( GMAIL_OAUTH_ENDPOINT_TOKEN = "https://accounts.google.com/o/oauth2/token"
GMAIL_OAUTH_ENDPOINT_AUTH = "https://accounts.google.com/o/oauth2/auth"
OUTLOOK_OAUTH_ENDPOINT_TOKEN = (
"https://login.microsoftonline.com/common/oauth2/v2.0/token" "https://login.microsoftonline.com/common/oauth2/v2.0/token"
) )
OUTLOOK_OAUTH_ENDPOINT_AUTH = ( OUTLOOK_OAUTH_ENDPOINT_AUTH = (
"https://login.microsoftonline.com/common/oauth2/v2.0/authorize" "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
) )
def __init__(self):
self._gmail_client = None
self._outlook_client = None
def get_oauth_callback_url() -> str: @property
def gmail_client(self) -> OAuth2:
if self._gmail_client is None:
self._gmail_client = OAuth2(
settings.GMAIL_OAUTH_CLIENT_ID,
settings.GMAIL_OAUTH_CLIENT_SECRET,
self.GMAIL_OAUTH_ENDPOINT_AUTH,
self.GMAIL_OAUTH_ENDPOINT_TOKEN,
refresh_token_endpoint=self.GMAIL_OAUTH_ENDPOINT_TOKEN,
token_endpoint_auth_method="client_secret_post",
)
return self._gmail_client
@property
def outlook_client(self) -> OAuth2:
if self._outlook_client is None:
self._outlook_client = OAuth2(
settings.OUTLOOK_OAUTH_CLIENT_ID,
settings.OUTLOOK_OAUTH_CLIENT_SECRET,
self.OUTLOOK_OAUTH_ENDPOINT_AUTH,
self.OUTLOOK_OAUTH_ENDPOINT_TOKEN,
refresh_token_endpoint=self.OUTLOOK_OAUTH_ENDPOINT_TOKEN,
token_endpoint_auth_method="client_secret_post",
)
return self._outlook_client
@property
def oauth_callback_url(self) -> str:
return f"{settings.OAUTH_CALLBACK_BASE_URL if settings.OAUTH_CALLBACK_BASE_URL is not None else settings.PAPERLESS_URL}{settings.BASE_URL}api/oauth/callback/" return f"{settings.OAUTH_CALLBACK_BASE_URL if settings.OAUTH_CALLBACK_BASE_URL is not None else settings.PAPERLESS_URL}{settings.BASE_URL}api/oauth/callback/"
@property
def get_oauth_redirect_url() -> str: def oauth_redirect_url(self) -> str:
return f"{'http://localhost:4200/' if settings.DEBUG else settings.BASE_URL}mail" # e.g. "http://localhost:4200/mail" or "/mail" return f"{'http://localhost:4200/' if settings.DEBUG else settings.BASE_URL}mail" # e.g. "http://localhost:4200/mail" or "/mail"
def get_gmail_authorization_url(self) -> str:
return asyncio.run(
self.gmail_client.get_authorization_url(
redirect_uri=self.oauth_callback_url,
scope=["https://mail.google.com/"],
extras_params={"prompt": "consent", "access_type": "offline"},
),
)
def generate_gmail_oauth_url() -> str: def get_outlook_authorization_url(self) -> str:
response_type = "code" return asyncio.run(
client_id = settings.GMAIL_OAUTH_CLIENT_ID self.outlook_client.get_authorization_url(
redirect_uri = get_oauth_callback_url() redirect_uri=self.oauth_callback_url,
scope = "https://mail.google.com/" scope=[
access_type = "offline" "offline_access",
url = f"{GMAIL_OAUTH_ENDPOINT_AUTH}?response_type={response_type}&client_id={client_id}&redirect_uri={redirect_uri}&scope={scope}&access_type={access_type}&prompt=consent" "https://outlook.office.com/IMAP.AccessAsUser.All",
return url ],
extras_params={"response_type": "code"},
),
)
def get_gmail_access_token(self, code: str) -> OAuth2Token:
return asyncio.run(
self.gmail_client.get_access_token(
code=code,
redirect_uri=self.oauth_callback_url,
),
)
def generate_outlook_oauth_url() -> str: def get_outlook_access_token(self, code: str) -> OAuth2Token:
response_type = "code" return asyncio.run(
client_id = settings.OUTLOOK_OAUTH_CLIENT_ID self.outlook_client.get_access_token(
redirect_uri = get_oauth_callback_url() code=code,
scope = "offline_access https://outlook.office.com/IMAP.AccessAsUser.All" redirect_uri=self.oauth_callback_url,
url = f"{OUTLOOK_OAUTH_ENDPOINT_AUTH}?response_type={response_type}&response_mode=query&client_id={client_id}&redirect_uri={redirect_uri}&scope={scope}" ),
return url )
def refresh_account_oauth_token(self, account: MailAccount) -> bool:
def generate_gmail_oauth_token_request_data(code: str) -> dict:
client_id = settings.GMAIL_OAUTH_CLIENT_ID
client_secret = settings.GMAIL_OAUTH_CLIENT_SECRET
scope = "https://mail.google.com/"
return {
"code": code,
"client_id": client_id,
"client_secret": client_secret,
"scope": scope,
"redirect_uri": get_oauth_callback_url(),
"grant_type": "authorization_code",
}
def generate_outlook_oauth_token_request_data(code: str) -> dict:
client_id = settings.OUTLOOK_OAUTH_CLIENT_ID
client_secret = settings.OUTLOOK_OAUTH_CLIENT_SECRET
scope = "offline_access https://outlook.office.com/IMAP.AccessAsUser.All"
return {
"code": code,
"client_id": client_id,
"client_secret": client_secret,
"scope": scope,
"redirect_uri": get_oauth_callback_url(),
"grant_type": "authorization_code",
}
def refresh_oauth_token(account: MailAccount) -> bool:
""" """
Refreshes the oauth token for the given mail account. Refreshes the oauth token for the given mail account.
""" """
@ -84,39 +106,27 @@ def refresh_oauth_token(account: MailAccount) -> bool:
logger.error(f"Account {account}: No refresh token available.") logger.error(f"Account {account}: No refresh token available.")
return False return False
try:
result: OAuth2Token
if account.account_type == MailAccount.MailAccountType.GMAIL_OAUTH: if account.account_type == MailAccount.MailAccountType.GMAIL_OAUTH:
url = GMAIL_OAUTH_ENDPOINT_TOKEN result = asyncio.run(
data = { self.gmail_client.refresh_token(
"client_id": settings.GMAIL_OAUTH_CLIENT_ID, refresh_token=account.refresh_token,
"client_secret": settings.GMAIL_OAUTH_CLIENT_SECRET, ),
"refresh_token": account.refresh_token,
"grant_type": "refresh_token",
}
elif account.account_type == MailAccount.MailAccountType.OUTLOOK_OAUTH:
url = OUTLOOK_OAUTH_ENDPOINT_TOKEN
data = {
"client_id": settings.OUTLOOK_OAUTH_CLIENT_ID,
"client_secret": settings.OUTLOOK_OAUTH_CLIENT_SECRET,
"refresh_token": account.refresh_token,
"grant_type": "refresh_token",
}
response = httpx.post(
url=url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
) )
data = response.json() elif account.account_type == MailAccount.MailAccountType.OUTLOOK_OAUTH:
if response.status_code < 400 and "access_token" in data: result = asyncio.run(
account.password = data["access_token"] self.outlook_client.refresh_token(
refresh_token=account.refresh_token,
),
)
account.password = result["access_token"]
account.expiration = timezone.now() + timedelta( account.expiration = timezone.now() + timedelta(
seconds=data["expires_in"], seconds=result["expires_in"],
) )
account.save() account.save()
logger.debug(f"Successfully refreshed oauth token for account {account}") logger.debug(f"Successfully refreshed oauth token for account {account}")
return True return True
else: except RefreshTokenError as e:
logger.error( logger.error(f"Failed to refresh oauth token for account {account}: {e}")
f"Failed to refresh oauth token for account {account}: {response}",
)
return False return False

View File

@ -35,7 +35,6 @@ from paperless_mail.mail import apply_mail_action
from paperless_mail.models import MailAccount from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule from paperless_mail.models import MailRule
from paperless_mail.models import ProcessedMail from paperless_mail.models import ProcessedMail
from paperless_mail.oauth import GMAIL_OAUTH_ENDPOINT_TOKEN
@dataclasses.dataclass @dataclasses.dataclass
@ -1636,11 +1635,19 @@ class TestMailAccountTestView(APITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.content.decode(), "Unable to connect to server") self.assertEqual(response.content.decode(), "Unable to connect to server")
@mock.patch("httpx.post") @mock.patch("paperless_mail.oauth.PaperlessMailOAuth2Manager")
def test_mail_account_test_view_refresh_token( def test_mail_account_test_view_refresh_token(
self, self,
mock_post, mock_manager,
): ):
"""
GIVEN:
- Mail account with expired token
WHEN:
- Mail account is tested
THEN:
- Should refresh the token
"""
existing_account = MailAccount.objects.create( existing_account = MailAccount.objects.create(
imap_server="imap.example.com", imap_server="imap.example.com",
imap_port=993, imap_port=993,
@ -1653,11 +1660,7 @@ class TestMailAccountTestView(APITestCase):
is_token=True, is_token=True,
) )
mock_post.return_value.status_code = status.HTTP_200_OK mock_manager.return_value.refresh_account_oauth_token.return_value = True
mock_post.return_value.json.return_value = {
"access_token": "newtoken",
"expires_in": 3600,
}
data = { data = {
"id": existing_account.id, "id": existing_account.id,
"imap_server": "imap.example.com", "imap_server": "imap.example.com",
@ -1668,13 +1671,22 @@ class TestMailAccountTestView(APITestCase):
"is_token": True, "is_token": True,
} }
self.client.post(self.url, data, format="json") self.client.post(self.url, data, format="json")
self.assertEqual(mock_post.call_args[1]["url"], GMAIL_OAUTH_ENDPOINT_TOKEN) self.assertEqual(mock_manager.call_count, 1)
@mock.patch("httpx.post") @mock.patch("paperless_mail.oauth.PaperlessMailOAuth2Manager")
def test_mail_account_test_view_refresh_token_fails( def test_mail_account_test_view_refresh_token_fails(
self, self,
mock_post, mock_manager,
): ):
"""
GIVEN:
- Mail account with expired token
WHEN:
- Mail account is tested
- Token refresh fails
THEN:
- Should log an error
"""
existing_account = MailAccount.objects.create( existing_account = MailAccount.objects.create(
imap_server="imap.example.com", imap_server="imap.example.com",
imap_port=993, imap_port=993,
@ -1687,10 +1699,7 @@ class TestMailAccountTestView(APITestCase):
is_token=True, is_token=True,
) )
mock_post.return_value.status_code = status.HTTP_400_BAD_REQUEST mock_manager.return_value.refresh_account_oauth_token.return_value = False
mock_post.return_value.json.return_value = {
"error": "invalid_grant",
}
data = { data = {
"id": existing_account.id, "id": existing_account.id,
"imap_server": "imap.example.com", "imap_server": "imap.example.com",
@ -1704,5 +1713,5 @@ class TestMailAccountTestView(APITestCase):
response = self.client.post(self.url, data, format="json") response = self.client.post(self.url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
error_str = cm.output[0] error_str = cm.output[0]
expected_str = "Failed to refresh oauth token for account" expected_str = "Unable to refresh oauth token"
self.assertIn(expected_str, error_str) self.assertIn(expected_str, error_str)

View File

@ -6,15 +6,12 @@ from django.contrib.auth.models import User
from django.test import TestCase from django.test import TestCase
from django.test import override_settings from django.test import override_settings
from django.utils import timezone from django.utils import timezone
from httpx_oauth.oauth2 import GetAccessTokenError
from rest_framework import status from rest_framework import status
from paperless_mail.mail import MailAccountHandler from paperless_mail.mail import MailAccountHandler
from paperless_mail.models import MailAccount from paperless_mail.models import MailAccount
from paperless_mail.oauth import generate_gmail_oauth_url from paperless_mail.oauth import PaperlessMailOAuth2Manager
from paperless_mail.oauth import generate_outlook_oauth_url
from paperless_mail.oauth import get_oauth_callback_url
from paperless_mail.oauth import get_oauth_redirect_url
from paperless_mail.oauth import refresh_oauth_token
class TestMailOAuth( class TestMailOAuth(
@ -42,9 +39,10 @@ class TestMailOAuth(
- Correct URLs are generated - Correct URLs are generated
""" """
# Callback URL # Callback URL
oauth_manager = PaperlessMailOAuth2Manager()
with override_settings(OAUTH_CALLBACK_BASE_URL="http://paperless.example.com"): with override_settings(OAUTH_CALLBACK_BASE_URL="http://paperless.example.com"):
self.assertEqual( self.assertEqual(
get_oauth_callback_url(), oauth_manager.oauth_callback_url,
"http://paperless.example.com/api/oauth/callback/", "http://paperless.example.com/api/oauth/callback/",
) )
with override_settings( with override_settings(
@ -52,7 +50,7 @@ class TestMailOAuth(
PAPERLESS_URL="http://paperless.example.com", PAPERLESS_URL="http://paperless.example.com",
): ):
self.assertEqual( self.assertEqual(
get_oauth_callback_url(), oauth_manager.oauth_callback_url,
"http://paperless.example.com/api/oauth/callback/", "http://paperless.example.com/api/oauth/callback/",
) )
with override_settings( with override_settings(
@ -61,42 +59,33 @@ class TestMailOAuth(
BASE_URL="/paperless/", BASE_URL="/paperless/",
): ):
self.assertEqual( self.assertEqual(
get_oauth_callback_url(), oauth_manager.oauth_callback_url,
"http://paperless.example.com/paperless/api/oauth/callback/", "http://paperless.example.com/paperless/api/oauth/callback/",
) )
# Redirect URL # Redirect URL
with override_settings(DEBUG=True): with override_settings(DEBUG=True):
self.assertEqual( self.assertEqual(
get_oauth_redirect_url(), oauth_manager.oauth_redirect_url,
"http://localhost:4200/mail", "http://localhost:4200/mail",
) )
with override_settings(DEBUG=False): with override_settings(DEBUG=False):
self.assertEqual( self.assertEqual(
get_oauth_redirect_url(), oauth_manager.oauth_redirect_url,
"/mail", "/mail",
) )
def test_generate_oauth_urls(self): @mock.patch(
""" "paperless_mail.oauth.PaperlessMailOAuth2Manager.get_gmail_access_token",
GIVEN:
- Mocked settings for Gmail and Outlook OAuth client IDs
WHEN:
- generate_gmail_oauth_url and generate_outlook_oauth_url are called
THEN:
- Correct URLs are generated
"""
self.assertEqual(
"https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=test_gmail_client_id&redirect_uri=http://localhost:8000/api/oauth/callback/&scope=https://mail.google.com/&access_type=offline&prompt=consent",
generate_gmail_oauth_url(),
) )
self.assertEqual( @mock.patch(
"https://login.microsoftonline.com/common/oauth2/v2.0/authorize?response_type=code&response_mode=query&client_id=test_outlook_client_id&redirect_uri=http://localhost:8000/api/oauth/callback/&scope=offline_access https://outlook.office.com/IMAP.AccessAsUser.All", "paperless_mail.oauth.PaperlessMailOAuth2Manager.get_outlook_access_token",
generate_outlook_oauth_url(),
) )
def test_oauth_callback_view(
@mock.patch("httpx.post") self,
def test_oauth_callback_view(self, mock_post): mock_get_outlook_access_token,
mock_get_gmail_access_token,
):
""" """
GIVEN: GIVEN:
- Mocked settings for Gmail and Outlook OAuth client IDs and secrets - Mocked settings for Gmail and Outlook OAuth client IDs and secrets
@ -106,7 +95,12 @@ class TestMailOAuth(
- Gmail mail account is created - Gmail mail account is created
""" """
mock_post.return_value.json.return_value = { mock_get_gmail_access_token.return_value = {
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
"expires_in": 3600,
}
mock_get_outlook_access_token.return_value = {
"access_token": "test_access_token", "access_token": "test_access_token",
"refresh_token": "test_refresh_token", "refresh_token": "test_refresh_token",
"expires_in": 3600, "expires_in": 3600,
@ -118,7 +112,7 @@ class TestMailOAuth(
) )
self.assertEqual(response.status_code, status.HTTP_302_FOUND) self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertIn("oauth_success=1", response.url) self.assertIn("oauth_success=1", response.url)
mock_post.assert_called_once() mock_get_gmail_access_token.assert_called_once()
self.assertTrue( self.assertTrue(
MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), MailAccount.objects.filter(imap_server="imap.gmail.com").exists(),
) )
@ -152,37 +146,51 @@ class TestMailOAuth(
MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), MailAccount.objects.filter(imap_server="outlook.office365.com").exists(),
) )
@mock.patch("httpx.post") @mock.patch("httpx_oauth.oauth2.BaseOAuth2.get_access_token")
def test_oauth_callback_view_error(self, mock_post): def test_oauth_callback_view_error(self, mock_get_access_token):
""" """
GIVEN: GIVEN:
- Mocked settings for Gmail and Outlook OAuth client IDs and secrets - Mocked settings for Gmail and Outlook OAuth client IDs and secrets
WHEN: WHEN:
- OAuth callback is called with an error - OAuth callback is called with an error
THEN: THEN:
- No mail account is created
- Error is logged - Error is logged
""" """
mock_get_access_token.side_effect = GetAccessTokenError("test_error")
mock_post.return_value.json.return_value = { with self.assertLogs("paperless_mail", level="ERROR") as cm:
"error": "test_error", # Test Google OAuth callback
}
response = self.client.get( response = self.client.get(
"/api/oauth/callback/?code=test_code&scope=https://mail.google.com/", "/api/oauth/callback/?code=test_code&scope=https://mail.google.com/",
) )
self.assertEqual(response.status_code, status.HTTP_302_FOUND) self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertIn("oauth_success=0", response.url) self.assertIn("oauth_success=0", response.url)
mock_post.assert_called_once()
self.assertFalse( self.assertFalse(
MailAccount.objects.filter(imap_server="imap.gmail.com").exists(), MailAccount.objects.filter(imap_server="imap.gmail.com").exists(),
) )
# Test Outlook OAuth callback
response = self.client.get("/api/oauth/callback/?code=test_code")
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertIn("oauth_success=0", response.url)
self.assertFalse( self.assertFalse(
MailAccount.objects.filter(imap_server="outlook.office365.com").exists(), MailAccount.objects.filter(
imap_server="outlook.office365.com",
).exists(),
) )
self.assertIn("Error getting access token: test_error", cm.output[0])
@mock.patch("paperless_mail.mail.get_mailbox") @mock.patch("paperless_mail.mail.get_mailbox")
@mock.patch("httpx.post") @mock.patch(
def test_refresh_token_on_handle_mail_account(self, mock_post, mock_get_mailbox): "paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token",
)
def test_refresh_token_on_handle_mail_account(
self,
mock_refresh_account_oauth_token,
mock_get_mailbox,
):
""" """
GIVEN: GIVEN:
- Mail account with refresh token and expiration - Mail account with refresh token and expiration
@ -206,16 +214,13 @@ class TestMailOAuth(
expiration=timezone.now() - timedelta(days=1), expiration=timezone.now() - timedelta(days=1),
) )
mock_post.return_value.status_code = 200 mock_refresh_account_oauth_token.return_value = True
mock_post.return_value.json.return_value = {
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
"expires_in": 3600,
}
self.mail_account_handler.handle_mail_account(mail_account) self.mail_account_handler.handle_mail_account(mail_account)
mock_post.assert_called_once() mock_refresh_account_oauth_token.assert_called_once()
mock_refresh_account_oauth_token.reset_mock()
mock_refresh_account_oauth_token.return_value = True
outlook_mail_account = MailAccount.objects.create( outlook_mail_account = MailAccount.objects.create(
name="Test Outlook Mail Account", name="Test Outlook Mail Account",
username="test_username", username="test_username",
@ -228,13 +233,15 @@ class TestMailOAuth(
) )
self.mail_account_handler.handle_mail_account(outlook_mail_account) self.mail_account_handler.handle_mail_account(outlook_mail_account)
mock_post.assert_called() mock_refresh_account_oauth_token.assert_called_once()
@mock.patch("paperless_mail.mail.get_mailbox") @mock.patch("paperless_mail.mail.get_mailbox")
@mock.patch("httpx.post") @mock.patch(
"paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token",
)
def test_refresh_token_on_handle_mail_account_fails( def test_refresh_token_on_handle_mail_account_fails(
self, self,
mock_post, mock_refresh_account_oauth_token,
mock_get_mailbox, mock_get_mailbox,
): ):
""" """
@ -261,37 +268,10 @@ class TestMailOAuth(
expiration=timezone.now() - timedelta(days=1), expiration=timezone.now() - timedelta(days=1),
) )
mock_post.return_value.status_code = 400 mock_refresh_account_oauth_token.return_value = False
mock_post.return_value.json.return_value = {
"error": "test_error",
}
self.assertEqual( self.assertEqual(
self.mail_account_handler.handle_mail_account(mail_account), self.mail_account_handler.handle_mail_account(mail_account),
0, 0,
) )
mock_post.assert_called_once() mock_refresh_account_oauth_token.assert_called_once()
def test_refresh_token_invalid_account(self):
"""
GIVEN:
- Mail account without refresh token
WHEN:
- refresh_oauth_token is called
THEN:
- False is returned
"""
mail_account = MailAccount.objects.create(
name="test_mail_account",
username="test_username",
imap_security=MailAccount.ImapSecurity.SSL,
imap_port=993,
account_type=MailAccount.MailAccountType.GMAIL_OAUTH,
is_token=True,
expiration=timezone.now() - timedelta(days=1),
)
self.assertFalse(
refresh_oauth_token(mail_account),
)

View File

@ -2,10 +2,10 @@ import datetime
import logging import logging
from datetime import timedelta from datetime import timedelta
import httpx
from django.http import HttpResponseBadRequest from django.http import HttpResponseBadRequest
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
from django.utils import timezone from django.utils import timezone
from httpx_oauth.oauth2 import GetAccessTokenError
from rest_framework.generics import GenericAPIView from rest_framework.generics import GenericAPIView
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
@ -20,12 +20,7 @@ from paperless_mail.mail import get_mailbox
from paperless_mail.mail import mailbox_login from paperless_mail.mail import mailbox_login
from paperless_mail.models import MailAccount from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule from paperless_mail.models import MailRule
from paperless_mail.oauth import GMAIL_OAUTH_ENDPOINT_TOKEN from paperless_mail.oauth import PaperlessMailOAuth2Manager
from paperless_mail.oauth import OUTLOOK_OAUTH_ENDPOINT_TOKEN
from paperless_mail.oauth import generate_gmail_oauth_token_request_data
from paperless_mail.oauth import generate_outlook_oauth_token_request_data
from paperless_mail.oauth import get_oauth_redirect_url
from paperless_mail.oauth import refresh_oauth_token
from paperless_mail.serialisers import MailAccountSerializer from paperless_mail.serialisers import MailAccountSerializer
from paperless_mail.serialisers import MailRuleSerializer from paperless_mail.serialisers import MailRuleSerializer
@ -83,7 +78,8 @@ class MailAccountTestView(GenericAPIView):
and account.expiration is not None and account.expiration is not None
and account.expiration < timezone.now() and account.expiration < timezone.now()
): ):
if refresh_oauth_token(existing_account): oauth_manager = PaperlessMailOAuth2Manager()
if oauth_manager.refresh_account_oauth_token(existing_account):
# User is not changing password and token needs to be refreshed # User is not changing password and token needs to be refreshed
existing_account.refresh_from_db() existing_account.refresh_from_db()
account.password = existing_account.password account.password = existing_account.password
@ -114,6 +110,9 @@ class OauthCallbackView(GenericAPIView):
) )
return HttpResponseBadRequest("Invalid request, see logs for more detail") return HttpResponseBadRequest("Invalid request, see logs for more detail")
oauth_manager = PaperlessMailOAuth2Manager()
try:
if scope is not None and "google" in scope: if scope is not None and "google" in scope:
# Google # Google
account_type = MailAccount.MailAccountType.GMAIL_OAUTH account_type = MailAccount.MailAccountType.GMAIL_OAUTH
@ -125,8 +124,7 @@ class OauthCallbackView(GenericAPIView):
"imap_port": 993, "imap_port": 993,
"account_type": account_type, "account_type": account_type,
} }
token_request_uri = GMAIL_OAUTH_ENDPOINT_TOKEN result = oauth_manager.get_gmail_access_token(code)
data = generate_gmail_oauth_token_request_data(code)
elif scope is None: elif scope is None:
# Outlook # Outlook
@ -140,24 +138,11 @@ class OauthCallbackView(GenericAPIView):
"account_type": account_type, "account_type": account_type,
} }
token_request_uri = OUTLOOK_OAUTH_ENDPOINT_TOKEN result = oauth_manager.get_outlook_access_token(code)
data = generate_outlook_oauth_token_request_data(code)
headers = { access_token = result["access_token"]
"Content-Type": "application/x-www-form-urlencoded", refresh_token = result["refresh_token"]
} expires_in = result["expires_in"]
response = httpx.post(token_request_uri, data=data, headers=headers)
data = response.json()
if "error" in data:
logger.error(f"Error {response.status_code} getting access token: {data}")
return HttpResponseRedirect(
f"{get_oauth_redirect_url()}?oauth_success=0",
)
elif "access_token" in data:
access_token = data["access_token"]
refresh_token = data["refresh_token"]
expires_in = data["expires_in"]
account, _ = MailAccount.objects.update_or_create( account, _ = MailAccount.objects.update_or_create(
password=access_token, password=access_token,
is_token=True, is_token=True,
@ -167,5 +152,10 @@ class OauthCallbackView(GenericAPIView):
defaults=defaults, defaults=defaults,
) )
return HttpResponseRedirect( return HttpResponseRedirect(
f"{get_oauth_redirect_url()}?oauth_success=1&account_id={account.pk}", f"{oauth_manager.oauth_redirect_url}?oauth_success=1&account_id={account.pk}",
)
except GetAccessTokenError as e:
logger.error(f"Error getting access token: {e}")
return HttpResponseRedirect(
f"{oauth_manager.oauth_redirect_url}?oauth_success=0",
) )