Basic global search api endpoint

This commit is contained in:
shamoon 2024-03-28 23:53:55 -07:00
parent 7a0334f353
commit b7a0c0eb6f
3 changed files with 194 additions and 0 deletions

View File

@ -4,6 +4,7 @@ from unittest import mock
import pytest import pytest
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission from django.contrib.auth.models import Permission
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.test import override_settings from django.test import override_settings
@ -22,7 +23,10 @@ from documents.models import DocumentType
from documents.models import Note from documents.models import Note
from documents.models import StoragePath from documents.models import StoragePath
from documents.models import Tag from documents.models import Tag
from documents.models import Workflow
from documents.tests.utils import DirectoriesMixin from documents.tests.utils import DirectoriesMixin
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
class TestDocumentSearchApi(DirectoriesMixin, APITestCase): class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
@ -1125,3 +1129,88 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
search_query("&ordering=-owner"), search_query("&ordering=-owner"),
[d3.id, d2.id, d1.id], [d3.id, d2.id, d1.id],
) )
def test_global_search(self):
    """
    GIVEN:
        - Three indexed documents, only two of which mention "bank"
        - One matching and one non-matching object of every other
          searchable type
    WHEN:
        - The global search endpoint is queried with "bank"
    THEN:
        - Only the two matching documents are returned
        - The first hit of every other object type is the matching object
    """
    invoice_doc = Document.objects.create(
        title="invoice doc1",
        content="the thing i bought at a shop and paid with bank account",
        checksum="A",
        pk=1,
    )
    statement_doc = Document.objects.create(
        title="bank statement doc2",
        content="things i paid for in august",
        checksum="B",
        pk=2,
    )
    unrelated_doc = Document.objects.create(
        title="tax bill doc3",
        content="no b word",
        checksum="C",
        pk=3,
    )

    # Documents are found through the full text index, so all three have to
    # be written to it before the endpoint is queried.
    with index.open_index_writer() as writer:
        for doc in (invoice_doc, statement_doc, unrelated_doc):
            index.update_document(writer, doc)

    # For every other object type create one object whose name contains
    # "bank" (kept for the assertions) and one whose name does not.
    matching_correspondent = Correspondent.objects.create(
        name="bank correspondent 1",
    )
    Correspondent.objects.create(name="correspondent 2")

    matching_document_type = DocumentType.objects.create(name="bank invoice")
    DocumentType.objects.create(name="invoice")

    matching_storage_path = StoragePath.objects.create(
        name="bank path 1",
        path="path1",
    )
    StoragePath.objects.create(name="path 2", path="path2")

    matching_tag = Tag.objects.create(name="bank tag1")
    Tag.objects.create(name="tag2")

    matching_user = User.objects.create_user("bank user1")
    User.objects.create_user("user2")

    matching_group = Group.objects.create(name="bank group1")
    Group.objects.create(name="group2")

    matching_mail_account = MailAccount.objects.create(
        name="bank mail account 1",
    )
    other_mail_account = MailAccount.objects.create(name="mail account 2")

    matching_mail_rule = MailRule.objects.create(
        name="bank mail rule 1",
        account=matching_mail_account,
        action=MailRule.MailAction.MOVE,
    )
    MailRule.objects.create(
        name="mail rule 2",
        account=other_mail_account,
        action=MailRule.MailAction.MOVE,
    )

    matching_custom_field = CustomField.objects.create(
        name="bank custom field 1",
        data_type=CustomField.FieldDataType.STRING,
    )
    CustomField.objects.create(
        name="custom field 2",
        data_type=CustomField.FieldDataType.INT,
    )

    matching_workflow = Workflow.objects.create(name="bank workflow 1")
    Workflow.objects.create(name="workflow 2")

    response = self.client.get("/api/search/?query=bank")

    self.assertEqual(response.status_code, status.HTTP_200_OK)
    results = response.data

    # Exactly two documents come back and neither is the unrelated one.
    document_hits = results["documents"]
    self.assertEqual(len(document_hits), 2)
    for position in (0, 1):
        self.assertNotEqual(document_hits[position]["id"], unrelated_doc.id)

    # Each remaining section must list the "bank" object first.
    expected_first_hits = (
        ("correspondents", matching_correspondent),
        ("document_types", matching_document_type),
        ("storage_paths", matching_storage_path),
        ("tags", matching_tag),
        ("users", matching_user),
        ("groups", matching_group),
        ("mail_accounts", matching_mail_account),
        ("mail_rules", matching_mail_rule),
        ("custom_fields", matching_custom_field),
        ("workflows", matching_workflow),
    )
    for section, expected in expected_first_hits:
        self.assertEqual(results[section][0]["id"], expected.id)

View File

@ -17,6 +17,7 @@ from urllib.parse import urlparse
import pathvalidate import pathvalidate
from django.apps import apps from django.apps import apps
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import connections from django.db import connections
@ -152,7 +153,13 @@ from paperless import version
from paperless.celery import app as celery_app from paperless.celery import app as celery_app
from paperless.config import GeneralConfig from paperless.config import GeneralConfig
from paperless.db import GnuPG from paperless.db import GnuPG
from paperless.serialisers import GroupSerializer
from paperless.serialisers import UserSerializer
from paperless.views import StandardPagination from paperless.views import StandardPagination
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
from paperless_mail.serialisers import MailAccountSerializer
from paperless_mail.serialisers import MailRuleSerializer
if settings.AUDIT_LOG_ENABLED: if settings.AUDIT_LOG_ENABLED:
from auditlog.models import LogEntry from auditlog.models import LogEntry
@ -1119,6 +1126,98 @@ class SearchAutoCompleteView(APIView):
) )
class GlobalSearchView(PassUserMixin):
    """
    Search documents and the other main object types with one query string.

    Documents are looked up through the full text index; every other object
    type is matched by a case-insensitive substring test on its name (or the
    username, for users). Only authenticated users may call the endpoint.
    """

    permission_classes = (IsAuthenticated,)
    serializer_class = SearchResultSerializer

    def get(self, request, *args, **kwargs):
        """
        Handle ``GET ?query=<text>``.

        Returns an HTTP 400 response when ``query`` is missing or shorter than
        three characters. Otherwise returns a mapping with one list of
        serialized objects per type: ``documents``, ``tags``,
        ``correspondents``, ``document_types``, ``storage_paths``, ``users``,
        ``groups``, ``mail_rules``, ``mail_accounts``, ``workflows`` and
        ``custom_fields``.
        """
        query = request.query_params.get("query", None)
        if query is None:
            return HttpResponseBadRequest("Query required")
        elif len(query) < 3:
            return HttpResponseBadRequest("Query must be at least 3 characters")

        from documents import index

        with index.open_index_searcher() as s:
            # NOTE(review): the positional 10 is presumably a page size and
            # should stay in sync with the search limit below - confirm
            # against DelayedFullTextQuery.
            q, _ = index.DelayedFullTextQuery(
                s,
                request.query_params,
                10,
                request.user,
            )._get_query()
            results = s.search(q, limit=10)
            # Pull the ids out while the searcher is still open.
            doc_ids = [r["id"] for r in results]

        docs = Document.objects.filter(id__in=doc_ids)

        # ``icontains`` rather than ``contains``: ``contains`` is
        # case-sensitive on some database backends and case-insensitive on
        # SQLite, so the same query would return different results depending
        # on the deployment.
        # NOTE(review): these querysets are neither limited in size nor
        # filtered by per-object permissions here - verify that is intended.
        tags = Tag.objects.filter(name__icontains=query)
        correspondents = Correspondent.objects.filter(name__icontains=query)
        document_types = DocumentType.objects.filter(name__icontains=query)
        storage_paths = StoragePath.objects.filter(name__icontains=query)
        users = User.objects.filter(username__icontains=query)
        groups = Group.objects.filter(name__icontains=query)
        mail_rules = MailRule.objects.filter(name__icontains=query)
        mail_accounts = MailAccount.objects.filter(name__icontains=query)
        workflows = Workflow.objects.filter(name__icontains=query)
        custom_fields = CustomField.objects.filter(name__icontains=query)

        context = {
            "request": request,
        }

        def serialize(serializer_cls, objects):
            # Every section is serialized the same way: as a list, sharing
            # the request context.
            return serializer_cls(objects, many=True, context=context).data

        return Response(
            {
                "documents": serialize(DocumentSerializer, docs),
                "tags": serialize(TagSerializer, tags),
                "correspondents": serialize(
                    CorrespondentSerializer,
                    correspondents,
                ),
                "document_types": serialize(
                    DocumentTypeSerializer,
                    document_types,
                ),
                "storage_paths": serialize(StoragePathSerializer, storage_paths),
                "users": serialize(UserSerializer, users),
                "groups": serialize(GroupSerializer, groups),
                "mail_rules": serialize(MailRuleSerializer, mail_rules),
                "mail_accounts": serialize(
                    MailAccountSerializer,
                    mail_accounts,
                ),
                "workflows": serialize(WorkflowSerializer, workflows),
                "custom_fields": serialize(CustomFieldSerializer, custom_fields),
            },
        )
class StatisticsView(APIView): class StatisticsView(APIView):
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)

View File

@ -21,6 +21,7 @@ from documents.views import BulkEditView
from documents.views import CorrespondentViewSet from documents.views import CorrespondentViewSet
from documents.views import CustomFieldViewSet from documents.views import CustomFieldViewSet
from documents.views import DocumentTypeViewSet from documents.views import DocumentTypeViewSet
from documents.views import GlobalSearchView
from documents.views import IndexView from documents.views import IndexView
from documents.views import LogViewSet from documents.views import LogViewSet
from documents.views import PostDocumentView from documents.views import PostDocumentView
@ -91,6 +92,11 @@ urlpatterns = [
SearchAutoCompleteView.as_view(), SearchAutoCompleteView.as_view(),
name="autocomplete", name="autocomplete",
), ),
re_path(
"^search/",
GlobalSearchView.as_view(),
name="global_search",
),
re_path("^statistics/", StatisticsView.as_view(), name="statistics"), re_path("^statistics/", StatisticsView.as_view(), name="statistics"),
re_path( re_path(
"^documents/post_document/", "^documents/post_document/",