diff --git a/src/documents/tests/test_api_search.py b/src/documents/tests/test_api_search.py index 1b46f8e33..486d35851 100644 --- a/src/documents/tests/test_api_search.py +++ b/src/documents/tests/test_api_search.py @@ -4,6 +4,7 @@ from unittest import mock import pytest from dateutil.relativedelta import relativedelta +from django.contrib.auth.models import Group from django.contrib.auth.models import Permission from django.contrib.auth.models import User from django.test import override_settings @@ -22,7 +23,10 @@ from documents.models import DocumentType from documents.models import Note from documents.models import StoragePath from documents.models import Tag +from documents.models import Workflow from documents.tests.utils import DirectoriesMixin +from paperless_mail.models import MailAccount +from paperless_mail.models import MailRule class TestDocumentSearchApi(DirectoriesMixin, APITestCase): @@ -1125,3 +1129,88 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase): search_query("&ordering=-owner"), [d3.id, d2.id, d1.id], ) + + def test_global_search(self): + """ + GIVEN: + - Multiple documents and objects + WHEN: + - Global search query is made + THEN: + - Appropriately filtered results are returned + """ + d1 = Document.objects.create( + title="invoice doc1", + content="the thing i bought at a shop and paid with bank account", + checksum="A", + pk=1, + ) + d2 = Document.objects.create( + title="bank statement doc2", + content="things i paid for in august", + checksum="B", + pk=2, + ) + d3 = Document.objects.create( + title="tax bill doc3", + content="no b word", + checksum="C", + pk=3, + ) + + with index.open_index_writer() as writer: + index.update_document(writer, d1) + index.update_document(writer, d2) + index.update_document(writer, d3) + + correspondent1 = Correspondent.objects.create(name="bank correspondent 1") + Correspondent.objects.create(name="correspondent 2") + document_type1 = DocumentType.objects.create(name="bank invoice") + DocumentType.objects.create(name="invoice") + storage_path1 = StoragePath.objects.create(name="bank path 1", path="path1") + StoragePath.objects.create(name="path 2", path="path2") + tag1 = Tag.objects.create(name="bank tag1") + Tag.objects.create(name="tag2") + user1 = User.objects.create_user("bank user1") + User.objects.create_user("user2") + group1 = Group.objects.create(name="bank group1") + Group.objects.create(name="group2") + mail_account1 = MailAccount.objects.create(name="bank mail account 1") + mail_account2 = MailAccount.objects.create(name="mail account 2") + mail_rule1 = MailRule.objects.create( + name="bank mail rule 1", + account=mail_account1, + action=MailRule.MailAction.MOVE, + ) + MailRule.objects.create( + name="mail rule 2", + account=mail_account2, + action=MailRule.MailAction.MOVE, + ) + custom_field1 = CustomField.objects.create( + name="bank custom field 1", + data_type=CustomField.FieldDataType.STRING, + ) + CustomField.objects.create( + name="custom field 2", + data_type=CustomField.FieldDataType.INT, + ) + workflow1 = Workflow.objects.create(name="bank workflow 1") + Workflow.objects.create(name="workflow 2") + + response = self.client.get("/api/search/?query=bank") + self.assertEqual(response.status_code, status.HTTP_200_OK) + results = response.data + self.assertEqual(len(results["documents"]), 2) + self.assertNotEqual(results["documents"][0]["id"], d3.id) + self.assertNotEqual(results["documents"][1]["id"], d3.id) + self.assertEqual(results["correspondents"][0]["id"], correspondent1.id) + self.assertEqual(results["document_types"][0]["id"], document_type1.id) + self.assertEqual(results["storage_paths"][0]["id"], storage_path1.id) + self.assertEqual(results["tags"][0]["id"], tag1.id) + self.assertEqual(results["users"][0]["id"], user1.id) + self.assertEqual(results["groups"][0]["id"], group1.id) + self.assertEqual(results["mail_accounts"][0]["id"], mail_account1.id) + self.assertEqual(results["mail_rules"][0]["id"], mail_rule1.id) + self.assertEqual(results["custom_fields"][0]["id"], custom_field1.id) + self.assertEqual(results["workflows"][0]["id"], workflow1.id) diff --git a/src/documents/views.py b/src/documents/views.py index d220d1aaa..d279c69df 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -17,6 +17,7 @@ from urllib.parse import urlparse import pathvalidate from django.apps import apps from django.conf import settings +from django.contrib.auth.models import Group from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType from django.db import connections @@ -152,7 +153,13 @@ from paperless import version from paperless.celery import app as celery_app from paperless.config import GeneralConfig from paperless.db import GnuPG +from paperless.serialisers import GroupSerializer +from paperless.serialisers import UserSerializer from paperless.views import StandardPagination +from paperless_mail.models import MailAccount +from paperless_mail.models import MailRule +from paperless_mail.serialisers import MailAccountSerializer +from paperless_mail.serialisers import MailRuleSerializer if settings.AUDIT_LOG_ENABLED: from auditlog.models import LogEntry @@ -1119,6 +1126,98 @@ class SearchAutoCompleteView(APIView): ) +class GlobalSearchView(PassUserMixin): + permission_classes = (IsAuthenticated,) + serializer_class = SearchResultSerializer + + def get(self, request, *args, **kwargs): + query = request.query_params.get("query", None) + if query is None: + return HttpResponseBadRequest("Query required") + elif len(query) < 3: + return HttpResponseBadRequest("Query must be at least 3 characters") + + docs = [] + from documents import index + + with index.open_index_searcher() as s: + q, _ = index.DelayedFullTextQuery( + s, + request.query_params, + 10, + request.user, + )._get_query() + results = s.search(q, limit=10) + docs = Document.objects.filter(id__in=[r["id"] for r in results]) + + tags = Tag.objects.filter(name__contains=query) + correspondents = Correspondent.objects.filter(name__contains=query) + document_types = DocumentType.objects.filter(name__contains=query) + storage_paths = StoragePath.objects.filter(name__contains=query) + users = User.objects.filter(username__contains=query) + groups = Group.objects.filter(name__contains=query) + mail_rules = MailRule.objects.filter(name__contains=query) + mail_accounts = MailAccount.objects.filter(name__contains=query) + workflows = Workflow.objects.filter(name__contains=query) + custom_fields = CustomField.objects.filter(name__contains=query) + + context = { + "request": request, + } + + docs_serializer = DocumentSerializer(docs, many=True, context=context) + tags_serializer = TagSerializer(tags, many=True, context=context) + correspondents_serializer = CorrespondentSerializer( + correspondents, + many=True, + context=context, + ) + document_types_serializer = DocumentTypeSerializer( + document_types, + many=True, + context=context, + ) + storage_paths_serializer = StoragePathSerializer( + storage_paths, + many=True, + context=context, + ) + users_serializer = UserSerializer(users, many=True, context=context) + groups_serializer = GroupSerializer(groups, many=True, context=context) + mail_rules_serializer = MailRuleSerializer( + mail_rules, + many=True, + context=context, + ) + mail_accounts_serializer = MailAccountSerializer( + mail_accounts, + many=True, + context=context, + ) + workflows_serializer = WorkflowSerializer(workflows, many=True, context=context) + custom_fields_serializer = CustomFieldSerializer( + custom_fields, + many=True, + context=context, + ) + + return Response( + { + "documents": docs_serializer.data, + "tags": tags_serializer.data, + "correspondents": correspondents_serializer.data, + "document_types": document_types_serializer.data, + "storage_paths": storage_paths_serializer.data, + "users": users_serializer.data, + "groups": groups_serializer.data, + "mail_rules": mail_rules_serializer.data, + "mail_accounts": mail_accounts_serializer.data, + "workflows": workflows_serializer.data, + "custom_fields": custom_fields_serializer.data, + }, + ) + + class StatisticsView(APIView): permission_classes = (IsAuthenticated,) diff --git a/src/paperless/urls.py b/src/paperless/urls.py index 12b049918..8626cc8b1 100644 --- a/src/paperless/urls.py +++ b/src/paperless/urls.py @@ -21,6 +21,7 @@ from documents.views import BulkEditView from documents.views import CorrespondentViewSet from documents.views import CustomFieldViewSet from documents.views import DocumentTypeViewSet +from documents.views import GlobalSearchView from documents.views import IndexView from documents.views import LogViewSet from documents.views import PostDocumentView @@ -91,6 +92,11 @@ urlpatterns = [ SearchAutoCompleteView.as_view(), name="autocomplete", ), + re_path( + "^search/", + GlobalSearchView.as_view(), + name="global_search", + ), re_path("^statistics/", StatisticsView.as_view(), name="statistics"), re_path( "^documents/post_document/",