Migrate to httpx_oauth
This commit is contained in:
parent
ad483b49a1
commit
855543377c
1
Pipfile
1
Pipfile
@ -58,6 +58,7 @@ whitenoise = "~=6.7"
|
|||||||
whoosh = "~=2.7"
|
whoosh = "~=2.7"
|
||||||
zxing-cpp = {version = "*", platform_machine = "== 'x86_64'"}
|
zxing-cpp = {version = "*", platform_machine = "== 'x86_64'"}
|
||||||
jinja2 = "~=3.1"
|
jinja2 = "~=3.1"
|
||||||
|
httpx-oauth = "*"
|
||||||
|
|
||||||
[dev-packages]
|
[dev-packages]
|
||||||
# Linting
|
# Linting
|
||||||
|
13
Pipfile.lock
generated
13
Pipfile.lock
generated
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"_meta": {
|
"_meta": {
|
||||||
"hash": {
|
"hash": {
|
||||||
"sha256": "1e113d0879e4e0bc3c384115057647ac8d9be05252dd7c708a1fc873f294ef28"
|
"sha256": "584249cbeaf29659c975000b5e02b12e45d768d795e4a8ac36118e73bd7c0b8a"
|
||||||
},
|
},
|
||||||
"pipfile-spec": 6,
|
"pipfile-spec": 6,
|
||||||
"requires": {},
|
"requires": {},
|
||||||
@ -799,9 +799,18 @@
|
|||||||
"sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0",
|
"sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0",
|
||||||
"sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"
|
"sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"
|
||||||
],
|
],
|
||||||
"markers": "python_version >= '3.9'",
|
"markers": "python_version >= '3.8'",
|
||||||
"version": "==0.27.2"
|
"version": "==0.27.2"
|
||||||
},
|
},
|
||||||
|
"httpx-oauth": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:4094cf0938fc7252b5f5dfd62cd1ab5aee2fcb6734e621942ee17d1af4806b74",
|
||||||
|
"sha256:89b45f250e93e42bbe9631adf349cab0e3d3ced958c07e06651735198d1bdf00"
|
||||||
|
],
|
||||||
|
"index": "pypi",
|
||||||
|
"markers": "python_version >= '3.8'",
|
||||||
|
"version": "==0.15.1"
|
||||||
|
},
|
||||||
"humanize": {
|
"humanize": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:06b6eb0293e4b85e8d385397c5868926820db32b9b654b932f57fa41c23c9978",
|
"sha256:06b6eb0293e4b85e8d385397c5868926820db32b9b654b932f57fa41c23c9978",
|
||||||
|
@ -160,8 +160,7 @@ from paperless.serialisers import UserSerializer
|
|||||||
from paperless.views import StandardPagination
|
from paperless.views import StandardPagination
|
||||||
from paperless_mail.models import MailAccount
|
from paperless_mail.models import MailAccount
|
||||||
from paperless_mail.models import MailRule
|
from paperless_mail.models import MailRule
|
||||||
from paperless_mail.oauth import generate_gmail_oauth_url
|
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
||||||
from paperless_mail.oauth import generate_outlook_oauth_url
|
|
||||||
from paperless_mail.serialisers import MailAccountSerializer
|
from paperless_mail.serialisers import MailAccountSerializer
|
||||||
from paperless_mail.serialisers import MailRuleSerializer
|
from paperless_mail.serialisers import MailRuleSerializer
|
||||||
|
|
||||||
@ -1586,11 +1585,14 @@ class UiSettingsView(GenericAPIView):
|
|||||||
|
|
||||||
ui_settings["auditlog_enabled"] = settings.AUDIT_LOG_ENABLED
|
ui_settings["auditlog_enabled"] = settings.AUDIT_LOG_ENABLED
|
||||||
|
|
||||||
|
if settings.GMAIL_OAUTH_ENABLED or settings.OUTLOOK_OAUTH_ENABLED:
|
||||||
|
manager = PaperlessMailOAuth2Manager()
|
||||||
if settings.GMAIL_OAUTH_ENABLED:
|
if settings.GMAIL_OAUTH_ENABLED:
|
||||||
ui_settings["gmail_oauth_url"] = generate_gmail_oauth_url()
|
ui_settings["gmail_oauth_url"] = manager.get_gmail_authorization_url()
|
||||||
|
|
||||||
if settings.OUTLOOK_OAUTH_ENABLED:
|
if settings.OUTLOOK_OAUTH_ENABLED:
|
||||||
ui_settings["outlook_oauth_url"] = generate_outlook_oauth_url()
|
ui_settings["outlook_oauth_url"] = (
|
||||||
|
manager.get_outlook_authorization_url()
|
||||||
|
)
|
||||||
|
|
||||||
user_resp = {
|
user_resp = {
|
||||||
"id": user.id,
|
"id": user.id,
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
@ -43,7 +44,7 @@ from documents.tasks import consume_file
|
|||||||
from paperless_mail.models import MailAccount
|
from paperless_mail.models import MailAccount
|
||||||
from paperless_mail.models import MailRule
|
from paperless_mail.models import MailRule
|
||||||
from paperless_mail.models import ProcessedMail
|
from paperless_mail.models import ProcessedMail
|
||||||
from paperless_mail.oauth import refresh_oauth_token
|
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
||||||
from paperless_mail.preprocessor import MailMessageDecryptor
|
from paperless_mail.preprocessor import MailMessageDecryptor
|
||||||
from paperless_mail.preprocessor import MailMessagePreprocessor
|
from paperless_mail.preprocessor import MailMessagePreprocessor
|
||||||
|
|
||||||
@ -537,7 +538,8 @@ class MailAccountHandler(LoggingMixin):
|
|||||||
and account.expiration is not None
|
and account.expiration is not None
|
||||||
and account.expiration < timezone.now()
|
and account.expiration < timezone.now()
|
||||||
):
|
):
|
||||||
if refresh_oauth_token(account):
|
manager = PaperlessMailOAuth2Manager()
|
||||||
|
if asyncio.run(manager.refresh_account_oauth_token(account)):
|
||||||
account.refresh_from_db()
|
account.refresh_from_db()
|
||||||
else:
|
else:
|
||||||
return total_processed_files
|
return total_processed_files
|
||||||
|
@ -1,12 +1,17 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
import httpx
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
from httpx_oauth.oauth2 import OAuth2
|
||||||
|
from httpx_oauth.oauth2 import OAuth2Token
|
||||||
|
from httpx_oauth.oauth2 import RefreshTokenError
|
||||||
|
|
||||||
from paperless_mail.models import MailAccount
|
from paperless_mail.models import MailAccount
|
||||||
|
|
||||||
|
|
||||||
|
class PaperlessMailOAuth2Manager:
|
||||||
GMAIL_OAUTH_ENDPOINT_TOKEN = "https://accounts.google.com/o/oauth2/token"
|
GMAIL_OAUTH_ENDPOINT_TOKEN = "https://accounts.google.com/o/oauth2/token"
|
||||||
GMAIL_OAUTH_ENDPOINT_AUTH = "https://accounts.google.com/o/oauth2/auth"
|
GMAIL_OAUTH_ENDPOINT_AUTH = "https://accounts.google.com/o/oauth2/auth"
|
||||||
OUTLOOK_OAUTH_ENDPOINT_TOKEN = (
|
OUTLOOK_OAUTH_ENDPOINT_TOKEN = (
|
||||||
@ -16,65 +21,82 @@ OUTLOOK_OAUTH_ENDPOINT_AUTH = (
|
|||||||
"https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
"https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._gmail_client = None
|
||||||
|
self._outlook_client = None
|
||||||
|
|
||||||
def get_oauth_callback_url() -> str:
|
@property
|
||||||
|
def gmail_client(self) -> OAuth2:
|
||||||
|
if self._gmail_client is None:
|
||||||
|
self._gmail_client = OAuth2(
|
||||||
|
settings.GMAIL_OAUTH_CLIENT_ID,
|
||||||
|
settings.GMAIL_OAUTH_CLIENT_SECRET,
|
||||||
|
self.GMAIL_OAUTH_ENDPOINT_AUTH,
|
||||||
|
self.GMAIL_OAUTH_ENDPOINT_TOKEN,
|
||||||
|
refresh_token_endpoint=self.GMAIL_OAUTH_ENDPOINT_TOKEN,
|
||||||
|
token_endpoint_auth_method="client_secret_post",
|
||||||
|
)
|
||||||
|
return self._gmail_client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outlook_client(self) -> OAuth2:
|
||||||
|
if self._outlook_client is None:
|
||||||
|
self._outlook_client = OAuth2(
|
||||||
|
settings.OUTLOOK_OAUTH_CLIENT_ID,
|
||||||
|
settings.OUTLOOK_OAUTH_CLIENT_SECRET,
|
||||||
|
self.OUTLOOK_OAUTH_ENDPOINT_AUTH,
|
||||||
|
self.OUTLOOK_OAUTH_ENDPOINT_TOKEN,
|
||||||
|
refresh_token_endpoint=self.OUTLOOK_OAUTH_ENDPOINT_TOKEN,
|
||||||
|
token_endpoint_auth_method="client_secret_post",
|
||||||
|
)
|
||||||
|
return self._outlook_client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def oauth_callback_url(self) -> str:
|
||||||
return f"{settings.OAUTH_CALLBACK_BASE_URL if settings.OAUTH_CALLBACK_BASE_URL is not None else settings.PAPERLESS_URL}{settings.BASE_URL}api/oauth/callback/"
|
return f"{settings.OAUTH_CALLBACK_BASE_URL if settings.OAUTH_CALLBACK_BASE_URL is not None else settings.PAPERLESS_URL}{settings.BASE_URL}api/oauth/callback/"
|
||||||
|
|
||||||
|
@property
|
||||||
def get_oauth_redirect_url() -> str:
|
def oauth_redirect_url(self) -> str:
|
||||||
return f"{'http://localhost:4200/' if settings.DEBUG else settings.BASE_URL}mail" # e.g. "http://localhost:4200/mail" or "/mail"
|
return f"{'http://localhost:4200/' if settings.DEBUG else settings.BASE_URL}mail" # e.g. "http://localhost:4200/mail" or "/mail"
|
||||||
|
|
||||||
|
def get_gmail_authorization_url(self) -> str:
|
||||||
|
return asyncio.run(
|
||||||
|
self.gmail_client.get_authorization_url(
|
||||||
|
redirect_uri=self.oauth_callback_url,
|
||||||
|
scope=["https://mail.google.com/"],
|
||||||
|
extras_params={"prompt": "consent", "access_type": "offline"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def generate_gmail_oauth_url() -> str:
|
def get_outlook_authorization_url(self) -> str:
|
||||||
response_type = "code"
|
return asyncio.run(
|
||||||
client_id = settings.GMAIL_OAUTH_CLIENT_ID
|
self.outlook_client.get_authorization_url(
|
||||||
redirect_uri = get_oauth_callback_url()
|
redirect_uri=self.oauth_callback_url,
|
||||||
scope = "https://mail.google.com/"
|
scope=[
|
||||||
access_type = "offline"
|
"offline_access",
|
||||||
url = f"{GMAIL_OAUTH_ENDPOINT_AUTH}?response_type={response_type}&client_id={client_id}&redirect_uri={redirect_uri}&scope={scope}&access_type={access_type}&prompt=consent"
|
"https://outlook.office.com/IMAP.AccessAsUser.All",
|
||||||
return url
|
],
|
||||||
|
extras_params={"response_type": "code"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_gmail_access_token(self, code: str) -> OAuth2Token:
|
||||||
|
return asyncio.run(
|
||||||
|
self.gmail_client.get_access_token(
|
||||||
|
code=code,
|
||||||
|
redirect_uri=self.oauth_callback_url,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def generate_outlook_oauth_url() -> str:
|
def get_outlook_access_token(self, code: str) -> OAuth2Token:
|
||||||
response_type = "code"
|
return asyncio.run(
|
||||||
client_id = settings.OUTLOOK_OAUTH_CLIENT_ID
|
self.outlook_client.get_access_token(
|
||||||
redirect_uri = get_oauth_callback_url()
|
code=code,
|
||||||
scope = "offline_access https://outlook.office.com/IMAP.AccessAsUser.All"
|
redirect_uri=self.oauth_callback_url,
|
||||||
url = f"{OUTLOOK_OAUTH_ENDPOINT_AUTH}?response_type={response_type}&response_mode=query&client_id={client_id}&redirect_uri={redirect_uri}&scope={scope}"
|
),
|
||||||
return url
|
)
|
||||||
|
|
||||||
|
def refresh_account_oauth_token(self, account: MailAccount) -> bool:
|
||||||
def generate_gmail_oauth_token_request_data(code: str) -> dict:
|
|
||||||
client_id = settings.GMAIL_OAUTH_CLIENT_ID
|
|
||||||
client_secret = settings.GMAIL_OAUTH_CLIENT_SECRET
|
|
||||||
scope = "https://mail.google.com/"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"code": code,
|
|
||||||
"client_id": client_id,
|
|
||||||
"client_secret": client_secret,
|
|
||||||
"scope": scope,
|
|
||||||
"redirect_uri": get_oauth_callback_url(),
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def generate_outlook_oauth_token_request_data(code: str) -> dict:
|
|
||||||
client_id = settings.OUTLOOK_OAUTH_CLIENT_ID
|
|
||||||
client_secret = settings.OUTLOOK_OAUTH_CLIENT_SECRET
|
|
||||||
scope = "offline_access https://outlook.office.com/IMAP.AccessAsUser.All"
|
|
||||||
|
|
||||||
return {
|
|
||||||
"code": code,
|
|
||||||
"client_id": client_id,
|
|
||||||
"client_secret": client_secret,
|
|
||||||
"scope": scope,
|
|
||||||
"redirect_uri": get_oauth_callback_url(),
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def refresh_oauth_token(account: MailAccount) -> bool:
|
|
||||||
"""
|
"""
|
||||||
Refreshes the oauth token for the given mail account.
|
Refreshes the oauth token for the given mail account.
|
||||||
"""
|
"""
|
||||||
@ -84,39 +106,27 @@ def refresh_oauth_token(account: MailAccount) -> bool:
|
|||||||
logger.error(f"Account {account}: No refresh token available.")
|
logger.error(f"Account {account}: No refresh token available.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
result: OAuth2Token
|
||||||
if account.account_type == MailAccount.MailAccountType.GMAIL_OAUTH:
|
if account.account_type == MailAccount.MailAccountType.GMAIL_OAUTH:
|
||||||
url = GMAIL_OAUTH_ENDPOINT_TOKEN
|
result = asyncio.run(
|
||||||
data = {
|
self.gmail_client.refresh_token(
|
||||||
"client_id": settings.GMAIL_OAUTH_CLIENT_ID,
|
refresh_token=account.refresh_token,
|
||||||
"client_secret": settings.GMAIL_OAUTH_CLIENT_SECRET,
|
),
|
||||||
"refresh_token": account.refresh_token,
|
|
||||||
"grant_type": "refresh_token",
|
|
||||||
}
|
|
||||||
elif account.account_type == MailAccount.MailAccountType.OUTLOOK_OAUTH:
|
|
||||||
url = OUTLOOK_OAUTH_ENDPOINT_TOKEN
|
|
||||||
data = {
|
|
||||||
"client_id": settings.OUTLOOK_OAUTH_CLIENT_ID,
|
|
||||||
"client_secret": settings.OUTLOOK_OAUTH_CLIENT_SECRET,
|
|
||||||
"refresh_token": account.refresh_token,
|
|
||||||
"grant_type": "refresh_token",
|
|
||||||
}
|
|
||||||
|
|
||||||
response = httpx.post(
|
|
||||||
url=url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
)
|
||||||
data = response.json()
|
elif account.account_type == MailAccount.MailAccountType.OUTLOOK_OAUTH:
|
||||||
if response.status_code < 400 and "access_token" in data:
|
result = asyncio.run(
|
||||||
account.password = data["access_token"]
|
self.outlook_client.refresh_token(
|
||||||
|
refresh_token=account.refresh_token,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
account.password = result["access_token"]
|
||||||
account.expiration = timezone.now() + timedelta(
|
account.expiration = timezone.now() + timedelta(
|
||||||
seconds=data["expires_in"],
|
seconds=result["expires_in"],
|
||||||
)
|
)
|
||||||
account.save()
|
account.save()
|
||||||
logger.debug(f"Successfully refreshed oauth token for account {account}")
|
logger.debug(f"Successfully refreshed oauth token for account {account}")
|
||||||
return True
|
return True
|
||||||
else:
|
except RefreshTokenError as e:
|
||||||
logger.error(
|
logger.error(f"Failed to refresh oauth token for account {account}: {e}")
|
||||||
f"Failed to refresh oauth token for account {account}: {response}",
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
@ -35,7 +35,6 @@ from paperless_mail.mail import apply_mail_action
|
|||||||
from paperless_mail.models import MailAccount
|
from paperless_mail.models import MailAccount
|
||||||
from paperless_mail.models import MailRule
|
from paperless_mail.models import MailRule
|
||||||
from paperless_mail.models import ProcessedMail
|
from paperless_mail.models import ProcessedMail
|
||||||
from paperless_mail.oauth import GMAIL_OAUTH_ENDPOINT_TOKEN
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@ -1636,11 +1635,19 @@ class TestMailAccountTestView(APITestCase):
|
|||||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
self.assertEqual(response.content.decode(), "Unable to connect to server")
|
self.assertEqual(response.content.decode(), "Unable to connect to server")
|
||||||
|
|
||||||
@mock.patch("httpx.post")
|
@mock.patch("paperless_mail.oauth.PaperlessMailOAuth2Manager")
|
||||||
def test_mail_account_test_view_refresh_token(
|
def test_mail_account_test_view_refresh_token(
|
||||||
self,
|
self,
|
||||||
mock_post,
|
mock_manager,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- Mail account with expired token
|
||||||
|
WHEN:
|
||||||
|
- Mail account is tested
|
||||||
|
THEN:
|
||||||
|
- Should refresh the token
|
||||||
|
"""
|
||||||
existing_account = MailAccount.objects.create(
|
existing_account = MailAccount.objects.create(
|
||||||
imap_server="imap.example.com",
|
imap_server="imap.example.com",
|
||||||
imap_port=993,
|
imap_port=993,
|
||||||
@ -1653,11 +1660,7 @@ class TestMailAccountTestView(APITestCase):
|
|||||||
is_token=True,
|
is_token=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_post.return_value.status_code = status.HTTP_200_OK
|
mock_manager.return_value.refresh_account_oauth_token.return_value = True
|
||||||
mock_post.return_value.json.return_value = {
|
|
||||||
"access_token": "newtoken",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
data = {
|
data = {
|
||||||
"id": existing_account.id,
|
"id": existing_account.id,
|
||||||
"imap_server": "imap.example.com",
|
"imap_server": "imap.example.com",
|
||||||
@ -1668,13 +1671,22 @@ class TestMailAccountTestView(APITestCase):
|
|||||||
"is_token": True,
|
"is_token": True,
|
||||||
}
|
}
|
||||||
self.client.post(self.url, data, format="json")
|
self.client.post(self.url, data, format="json")
|
||||||
self.assertEqual(mock_post.call_args[1]["url"], GMAIL_OAUTH_ENDPOINT_TOKEN)
|
self.assertEqual(mock_manager.call_count, 1)
|
||||||
|
|
||||||
@mock.patch("httpx.post")
|
@mock.patch("paperless_mail.oauth.PaperlessMailOAuth2Manager")
|
||||||
def test_mail_account_test_view_refresh_token_fails(
|
def test_mail_account_test_view_refresh_token_fails(
|
||||||
self,
|
self,
|
||||||
mock_post,
|
mock_manager,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
GIVEN:
|
||||||
|
- Mail account with expired token
|
||||||
|
WHEN:
|
||||||
|
- Mail account is tested
|
||||||
|
- Token refresh fails
|
||||||
|
THEN:
|
||||||
|
- Should log an error
|
||||||
|
"""
|
||||||
existing_account = MailAccount.objects.create(
|
existing_account = MailAccount.objects.create(
|
||||||
imap_server="imap.example.com",
|
imap_server="imap.example.com",
|
||||||
imap_port=993,
|
imap_port=993,
|
||||||
@ -1687,10 +1699,7 @@ class TestMailAccountTestView(APITestCase):
|
|||||||
is_token=True,
|
is_token=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_post.return_value.status_code = status.HTTP_400_BAD_REQUEST
|
mock_manager.return_value.refresh_account_oauth_token.return_value = False
|
||||||
mock_post.return_value.json.return_value = {
|
|
||||||
"error": "invalid_grant",
|
|
||||||
}
|
|
||||||
data = {
|
data = {
|
||||||
"id": existing_account.id,
|
"id": existing_account.id,
|
||||||
"imap_server": "imap.example.com",
|
"imap_server": "imap.example.com",
|
||||||
@ -1704,5 +1713,5 @@ class TestMailAccountTestView(APITestCase):
|
|||||||
response = self.client.post(self.url, data, format="json")
|
response = self.client.post(self.url, data, format="json")
|
||||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||||
error_str = cm.output[0]
|
error_str = cm.output[0]
|
||||||
expected_str = "Failed to refresh oauth token for account"
|
expected_str = "Unable to refresh oauth token"
|
||||||
self.assertIn(expected_str, error_str)
|
self.assertIn(expected_str, error_str)
|
||||||
|
@ -6,15 +6,12 @@ from django.contrib.auth.models import User
|
|||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
from httpx_oauth.oauth2 import GetAccessTokenError
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
|
|
||||||
from paperless_mail.mail import MailAccountHandler
|
from paperless_mail.mail import MailAccountHandler
|
||||||
from paperless_mail.models import MailAccount
|
from paperless_mail.models import MailAccount
|
||||||
from paperless_mail.oauth import generate_gmail_oauth_url
|
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
||||||
from paperless_mail.oauth import generate_outlook_oauth_url
|
|
||||||
from paperless_mail.oauth import get_oauth_callback_url
|
|
||||||
from paperless_mail.oauth import get_oauth_redirect_url
|
|
||||||
from paperless_mail.oauth import refresh_oauth_token
|
|
||||||
|
|
||||||
|
|
||||||
class TestMailOAuth(
|
class TestMailOAuth(
|
||||||
@ -42,9 +39,10 @@ class TestMailOAuth(
|
|||||||
- Correct URLs are generated
|
- Correct URLs are generated
|
||||||
"""
|
"""
|
||||||
# Callback URL
|
# Callback URL
|
||||||
|
oauth_manager = PaperlessMailOAuth2Manager()
|
||||||
with override_settings(OAUTH_CALLBACK_BASE_URL="http://paperless.example.com"):
|
with override_settings(OAUTH_CALLBACK_BASE_URL="http://paperless.example.com"):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
get_oauth_callback_url(),
|
oauth_manager.oauth_callback_url,
|
||||||
"http://paperless.example.com/api/oauth/callback/",
|
"http://paperless.example.com/api/oauth/callback/",
|
||||||
)
|
)
|
||||||
with override_settings(
|
with override_settings(
|
||||||
@ -52,7 +50,7 @@ class TestMailOAuth(
|
|||||||
PAPERLESS_URL="http://paperless.example.com",
|
PAPERLESS_URL="http://paperless.example.com",
|
||||||
):
|
):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
get_oauth_callback_url(),
|
oauth_manager.oauth_callback_url,
|
||||||
"http://paperless.example.com/api/oauth/callback/",
|
"http://paperless.example.com/api/oauth/callback/",
|
||||||
)
|
)
|
||||||
with override_settings(
|
with override_settings(
|
||||||
@ -61,42 +59,33 @@ class TestMailOAuth(
|
|||||||
BASE_URL="/paperless/",
|
BASE_URL="/paperless/",
|
||||||
):
|
):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
get_oauth_callback_url(),
|
oauth_manager.oauth_callback_url,
|
||||||
"http://paperless.example.com/paperless/api/oauth/callback/",
|
"http://paperless.example.com/paperless/api/oauth/callback/",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Redirect URL
|
# Redirect URL
|
||||||
with override_settings(DEBUG=True):
|
with override_settings(DEBUG=True):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
get_oauth_redirect_url(),
|
oauth_manager.oauth_redirect_url,
|
||||||
"http://localhost:4200/mail",
|
"http://localhost:4200/mail",
|
||||||
)
|
)
|
||||||
with override_settings(DEBUG=False):
|
with override_settings(DEBUG=False):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
get_oauth_redirect_url(),
|
oauth_manager.oauth_redirect_url,
|
||||||
"/mail",
|
"/mail",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_generate_oauth_urls(self):
|
@mock.patch(
|
||||||
"""
|
"paperless_mail.oauth.PaperlessMailOAuth2Manager.get_gmail_access_token",
|
||||||
GIVEN:
|
|
||||||
- Mocked settings for Gmail and Outlook OAuth client IDs
|
|
||||||
WHEN:
|
|
||||||
- generate_gmail_oauth_url and generate_outlook_oauth_url are called
|
|
||||||
THEN:
|
|
||||||
- Correct URLs are generated
|
|
||||||
"""
|
|
||||||
self.assertEqual(
|
|
||||||
"https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=test_gmail_client_id&redirect_uri=http://localhost:8000/api/oauth/callback/&scope=https://mail.google.com/&access_type=offline&prompt=consent",
|
|
||||||
generate_gmail_oauth_url(),
|
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
@mock.patch(
|
||||||
"https://login.microsoftonline.com/common/oauth2/v2.0/authorize?response_type=code&response_mode=query&client_id=test_outlook_client_id&redirect_uri=http://localhost:8000/api/oauth/callback/&scope=offline_access https://outlook.office.com/IMAP.AccessAsUser.All",
|
"paperless_mail.oauth.PaperlessMailOAuth2Manager.get_outlook_access_token",
|
||||||
generate_outlook_oauth_url(),
|
|
||||||
)
|
)
|
||||||
|
def test_oauth_callback_view(
|
||||||
@mock.patch("httpx.post")
|
self,
|
||||||
def test_oauth_callback_view(self, mock_post):
|
mock_get_outlook_access_token,
|
||||||
|
mock_get_gmail_access_token,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
- Mocked settings for Gmail and Outlook OAuth client IDs and secrets
|
- Mocked settings for Gmail and Outlook OAuth client IDs and secrets
|
||||||
@ -106,7 +95,12 @@ class TestMailOAuth(
|
|||||||
- Gmail mail account is created
|
- Gmail mail account is created
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mock_post.return_value.json.return_value = {
|
mock_get_gmail_access_token.return_value = {
|
||||||
|
"access_token": "test_access_token",
|
||||||
|
"refresh_token": "test_refresh_token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
mock_get_outlook_access_token.return_value = {
|
||||||
"access_token": "test_access_token",
|
"access_token": "test_access_token",
|
||||||
"refresh_token": "test_refresh_token",
|
"refresh_token": "test_refresh_token",
|
||||||
"expires_in": 3600,
|
"expires_in": 3600,
|
||||||
@ -118,7 +112,7 @@ class TestMailOAuth(
|
|||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
||||||
self.assertIn("oauth_success=1", response.url)
|
self.assertIn("oauth_success=1", response.url)
|
||||||
mock_post.assert_called_once()
|
mock_get_gmail_access_token.assert_called_once()
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
MailAccount.objects.filter(imap_server="imap.gmail.com").exists(),
|
MailAccount.objects.filter(imap_server="imap.gmail.com").exists(),
|
||||||
)
|
)
|
||||||
@ -152,37 +146,51 @@ class TestMailOAuth(
|
|||||||
MailAccount.objects.filter(imap_server="outlook.office365.com").exists(),
|
MailAccount.objects.filter(imap_server="outlook.office365.com").exists(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@mock.patch("httpx.post")
|
@mock.patch("httpx_oauth.oauth2.BaseOAuth2.get_access_token")
|
||||||
def test_oauth_callback_view_error(self, mock_post):
|
def test_oauth_callback_view_error(self, mock_get_access_token):
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
- Mocked settings for Gmail and Outlook OAuth client IDs and secrets
|
- Mocked settings for Gmail and Outlook OAuth client IDs and secrets
|
||||||
WHEN:
|
WHEN:
|
||||||
- OAuth callback is called with an error
|
- OAuth callback is called with an error
|
||||||
THEN:
|
THEN:
|
||||||
|
- No mail account is created
|
||||||
- Error is logged
|
- Error is logged
|
||||||
"""
|
"""
|
||||||
|
mock_get_access_token.side_effect = GetAccessTokenError("test_error")
|
||||||
|
|
||||||
mock_post.return_value.json.return_value = {
|
with self.assertLogs("paperless_mail", level="ERROR") as cm:
|
||||||
"error": "test_error",
|
# Test Google OAuth callback
|
||||||
}
|
|
||||||
|
|
||||||
response = self.client.get(
|
response = self.client.get(
|
||||||
"/api/oauth/callback/?code=test_code&scope=https://mail.google.com/",
|
"/api/oauth/callback/?code=test_code&scope=https://mail.google.com/",
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
||||||
self.assertIn("oauth_success=0", response.url)
|
self.assertIn("oauth_success=0", response.url)
|
||||||
mock_post.assert_called_once()
|
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
MailAccount.objects.filter(imap_server="imap.gmail.com").exists(),
|
MailAccount.objects.filter(imap_server="imap.gmail.com").exists(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test Outlook OAuth callback
|
||||||
|
response = self.client.get("/api/oauth/callback/?code=test_code")
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
|
||||||
|
self.assertIn("oauth_success=0", response.url)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
MailAccount.objects.filter(imap_server="outlook.office365.com").exists(),
|
MailAccount.objects.filter(
|
||||||
|
imap_server="outlook.office365.com",
|
||||||
|
).exists(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assertIn("Error getting access token: test_error", cm.output[0])
|
||||||
|
|
||||||
@mock.patch("paperless_mail.mail.get_mailbox")
|
@mock.patch("paperless_mail.mail.get_mailbox")
|
||||||
@mock.patch("httpx.post")
|
@mock.patch(
|
||||||
def test_refresh_token_on_handle_mail_account(self, mock_post, mock_get_mailbox):
|
"paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token",
|
||||||
|
)
|
||||||
|
def test_refresh_token_on_handle_mail_account(
|
||||||
|
self,
|
||||||
|
mock_refresh_account_oauth_token,
|
||||||
|
mock_get_mailbox,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
GIVEN:
|
GIVEN:
|
||||||
- Mail account with refresh token and expiration
|
- Mail account with refresh token and expiration
|
||||||
@ -206,16 +214,13 @@ class TestMailOAuth(
|
|||||||
expiration=timezone.now() - timedelta(days=1),
|
expiration=timezone.now() - timedelta(days=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_post.return_value.status_code = 200
|
mock_refresh_account_oauth_token.return_value = True
|
||||||
mock_post.return_value.json.return_value = {
|
|
||||||
"access_token": "test_access_token",
|
|
||||||
"refresh_token": "test_refresh_token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.mail_account_handler.handle_mail_account(mail_account)
|
self.mail_account_handler.handle_mail_account(mail_account)
|
||||||
mock_post.assert_called_once()
|
mock_refresh_account_oauth_token.assert_called_once()
|
||||||
|
mock_refresh_account_oauth_token.reset_mock()
|
||||||
|
|
||||||
|
mock_refresh_account_oauth_token.return_value = True
|
||||||
outlook_mail_account = MailAccount.objects.create(
|
outlook_mail_account = MailAccount.objects.create(
|
||||||
name="Test Outlook Mail Account",
|
name="Test Outlook Mail Account",
|
||||||
username="test_username",
|
username="test_username",
|
||||||
@ -228,13 +233,15 @@ class TestMailOAuth(
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.mail_account_handler.handle_mail_account(outlook_mail_account)
|
self.mail_account_handler.handle_mail_account(outlook_mail_account)
|
||||||
mock_post.assert_called()
|
mock_refresh_account_oauth_token.assert_called_once()
|
||||||
|
|
||||||
@mock.patch("paperless_mail.mail.get_mailbox")
|
@mock.patch("paperless_mail.mail.get_mailbox")
|
||||||
@mock.patch("httpx.post")
|
@mock.patch(
|
||||||
|
"paperless_mail.oauth.PaperlessMailOAuth2Manager.refresh_account_oauth_token",
|
||||||
|
)
|
||||||
def test_refresh_token_on_handle_mail_account_fails(
|
def test_refresh_token_on_handle_mail_account_fails(
|
||||||
self,
|
self,
|
||||||
mock_post,
|
mock_refresh_account_oauth_token,
|
||||||
mock_get_mailbox,
|
mock_get_mailbox,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -261,37 +268,10 @@ class TestMailOAuth(
|
|||||||
expiration=timezone.now() - timedelta(days=1),
|
expiration=timezone.now() - timedelta(days=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_post.return_value.status_code = 400
|
mock_refresh_account_oauth_token.return_value = False
|
||||||
mock_post.return_value.json.return_value = {
|
|
||||||
"error": "test_error",
|
|
||||||
}
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.mail_account_handler.handle_mail_account(mail_account),
|
self.mail_account_handler.handle_mail_account(mail_account),
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
mock_post.assert_called_once()
|
mock_refresh_account_oauth_token.assert_called_once()
|
||||||
|
|
||||||
def test_refresh_token_invalid_account(self):
|
|
||||||
"""
|
|
||||||
GIVEN:
|
|
||||||
- Mail account without refresh token
|
|
||||||
WHEN:
|
|
||||||
- refresh_oauth_token is called
|
|
||||||
THEN:
|
|
||||||
- False is returned
|
|
||||||
"""
|
|
||||||
|
|
||||||
mail_account = MailAccount.objects.create(
|
|
||||||
name="test_mail_account",
|
|
||||||
username="test_username",
|
|
||||||
imap_security=MailAccount.ImapSecurity.SSL,
|
|
||||||
imap_port=993,
|
|
||||||
account_type=MailAccount.MailAccountType.GMAIL_OAUTH,
|
|
||||||
is_token=True,
|
|
||||||
expiration=timezone.now() - timedelta(days=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertFalse(
|
|
||||||
refresh_oauth_token(mail_account),
|
|
||||||
)
|
|
||||||
|
@ -2,10 +2,10 @@ import datetime
|
|||||||
import logging
|
import logging
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
import httpx
|
|
||||||
from django.http import HttpResponseBadRequest
|
from django.http import HttpResponseBadRequest
|
||||||
from django.http import HttpResponseRedirect
|
from django.http import HttpResponseRedirect
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
from httpx_oauth.oauth2 import GetAccessTokenError
|
||||||
from rest_framework.generics import GenericAPIView
|
from rest_framework.generics import GenericAPIView
|
||||||
from rest_framework.permissions import IsAuthenticated
|
from rest_framework.permissions import IsAuthenticated
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
@ -20,12 +20,7 @@ from paperless_mail.mail import get_mailbox
|
|||||||
from paperless_mail.mail import mailbox_login
|
from paperless_mail.mail import mailbox_login
|
||||||
from paperless_mail.models import MailAccount
|
from paperless_mail.models import MailAccount
|
||||||
from paperless_mail.models import MailRule
|
from paperless_mail.models import MailRule
|
||||||
from paperless_mail.oauth import GMAIL_OAUTH_ENDPOINT_TOKEN
|
from paperless_mail.oauth import PaperlessMailOAuth2Manager
|
||||||
from paperless_mail.oauth import OUTLOOK_OAUTH_ENDPOINT_TOKEN
|
|
||||||
from paperless_mail.oauth import generate_gmail_oauth_token_request_data
|
|
||||||
from paperless_mail.oauth import generate_outlook_oauth_token_request_data
|
|
||||||
from paperless_mail.oauth import get_oauth_redirect_url
|
|
||||||
from paperless_mail.oauth import refresh_oauth_token
|
|
||||||
from paperless_mail.serialisers import MailAccountSerializer
|
from paperless_mail.serialisers import MailAccountSerializer
|
||||||
from paperless_mail.serialisers import MailRuleSerializer
|
from paperless_mail.serialisers import MailRuleSerializer
|
||||||
|
|
||||||
@ -83,7 +78,8 @@ class MailAccountTestView(GenericAPIView):
|
|||||||
and account.expiration is not None
|
and account.expiration is not None
|
||||||
and account.expiration < timezone.now()
|
and account.expiration < timezone.now()
|
||||||
):
|
):
|
||||||
if refresh_oauth_token(existing_account):
|
oauth_manager = PaperlessMailOAuth2Manager()
|
||||||
|
if oauth_manager.refresh_account_oauth_token(existing_account):
|
||||||
# User is not changing password and token needs to be refreshed
|
# User is not changing password and token needs to be refreshed
|
||||||
existing_account.refresh_from_db()
|
existing_account.refresh_from_db()
|
||||||
account.password = existing_account.password
|
account.password = existing_account.password
|
||||||
@ -114,6 +110,9 @@ class OauthCallbackView(GenericAPIView):
|
|||||||
)
|
)
|
||||||
return HttpResponseBadRequest("Invalid request, see logs for more detail")
|
return HttpResponseBadRequest("Invalid request, see logs for more detail")
|
||||||
|
|
||||||
|
oauth_manager = PaperlessMailOAuth2Manager()
|
||||||
|
|
||||||
|
try:
|
||||||
if scope is not None and "google" in scope:
|
if scope is not None and "google" in scope:
|
||||||
# Google
|
# Google
|
||||||
account_type = MailAccount.MailAccountType.GMAIL_OAUTH
|
account_type = MailAccount.MailAccountType.GMAIL_OAUTH
|
||||||
@ -125,8 +124,7 @@ class OauthCallbackView(GenericAPIView):
|
|||||||
"imap_port": 993,
|
"imap_port": 993,
|
||||||
"account_type": account_type,
|
"account_type": account_type,
|
||||||
}
|
}
|
||||||
token_request_uri = GMAIL_OAUTH_ENDPOINT_TOKEN
|
result = oauth_manager.get_gmail_access_token(code)
|
||||||
data = generate_gmail_oauth_token_request_data(code)
|
|
||||||
|
|
||||||
elif scope is None:
|
elif scope is None:
|
||||||
# Outlook
|
# Outlook
|
||||||
@ -140,24 +138,11 @@ class OauthCallbackView(GenericAPIView):
|
|||||||
"account_type": account_type,
|
"account_type": account_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
token_request_uri = OUTLOOK_OAUTH_ENDPOINT_TOKEN
|
result = oauth_manager.get_outlook_access_token(code)
|
||||||
data = generate_outlook_oauth_token_request_data(code)
|
|
||||||
|
|
||||||
headers = {
|
access_token = result["access_token"]
|
||||||
"Content-Type": "application/x-www-form-urlencoded",
|
refresh_token = result["refresh_token"]
|
||||||
}
|
expires_in = result["expires_in"]
|
||||||
response = httpx.post(token_request_uri, data=data, headers=headers)
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if "error" in data:
|
|
||||||
logger.error(f"Error {response.status_code} getting access token: {data}")
|
|
||||||
return HttpResponseRedirect(
|
|
||||||
f"{get_oauth_redirect_url()}?oauth_success=0",
|
|
||||||
)
|
|
||||||
elif "access_token" in data:
|
|
||||||
access_token = data["access_token"]
|
|
||||||
refresh_token = data["refresh_token"]
|
|
||||||
expires_in = data["expires_in"]
|
|
||||||
account, _ = MailAccount.objects.update_or_create(
|
account, _ = MailAccount.objects.update_or_create(
|
||||||
password=access_token,
|
password=access_token,
|
||||||
is_token=True,
|
is_token=True,
|
||||||
@ -167,5 +152,10 @@ class OauthCallbackView(GenericAPIView):
|
|||||||
defaults=defaults,
|
defaults=defaults,
|
||||||
)
|
)
|
||||||
return HttpResponseRedirect(
|
return HttpResponseRedirect(
|
||||||
f"{get_oauth_redirect_url()}?oauth_success=1&account_id={account.pk}",
|
f"{oauth_manager.oauth_redirect_url}?oauth_success=1&account_id={account.pk}",
|
||||||
|
)
|
||||||
|
except GetAccessTokenError as e:
|
||||||
|
logger.error(f"Error getting access token: {e}")
|
||||||
|
return HttpResponseRedirect(
|
||||||
|
f"{oauth_manager.oauth_redirect_url}?oauth_success=0",
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user