Basic global search api endpoint
This commit is contained in:
parent
7a0334f353
commit
b7a0c0eb6f
@ -4,6 +4,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import Permission
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import override_settings
|
||||
@ -22,7 +23,10 @@ from documents.models import DocumentType
|
||||
from documents.models import Note
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.models import Workflow
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
from paperless_mail.models import MailAccount
|
||||
from paperless_mail.models import MailRule
|
||||
|
||||
|
||||
class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
@ -1125,3 +1129,88 @@ class TestDocumentSearchApi(DirectoriesMixin, APITestCase):
|
||||
search_query("&ordering=-owner"),
|
||||
[d3.id, d2.id, d1.id],
|
||||
)
|
||||
|
||||
def test_global_search(self):
|
||||
"""
|
||||
GIVEN:
|
||||
- Multiple documents and objects
|
||||
WHEN:
|
||||
- Global search query is made
|
||||
THEN:
|
||||
- Appropriately filtered results are returned
|
||||
"""
|
||||
d1 = Document.objects.create(
|
||||
title="invoice doc1",
|
||||
content="the thing i bought at a shop and paid with bank account",
|
||||
checksum="A",
|
||||
pk=1,
|
||||
)
|
||||
d2 = Document.objects.create(
|
||||
title="bank statement doc2",
|
||||
content="things i paid for in august",
|
||||
checksum="B",
|
||||
pk=2,
|
||||
)
|
||||
d3 = Document.objects.create(
|
||||
title="tax bill doc3",
|
||||
content="no b word",
|
||||
checksum="C",
|
||||
pk=3,
|
||||
)
|
||||
|
||||
with index.open_index_writer() as writer:
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
|
||||
correspondent1 = Correspondent.objects.create(name="bank correspondent 1")
|
||||
Correspondent.objects.create(name="correspondent 2")
|
||||
document_type1 = DocumentType.objects.create(name="bank invoice")
|
||||
DocumentType.objects.create(name="invoice")
|
||||
storage_path1 = StoragePath.objects.create(name="bank path 1", path="path1")
|
||||
StoragePath.objects.create(name="path 2", path="path2")
|
||||
tag1 = Tag.objects.create(name="bank tag1")
|
||||
Tag.objects.create(name="tag2")
|
||||
user1 = User.objects.create_user("bank user1")
|
||||
User.objects.create_user("user2")
|
||||
group1 = Group.objects.create(name="bank group1")
|
||||
Group.objects.create(name="group2")
|
||||
mail_account1 = MailAccount.objects.create(name="bank mail account 1")
|
||||
mail_account2 = MailAccount.objects.create(name="mail account 2")
|
||||
mail_rule1 = MailRule.objects.create(
|
||||
name="bank mail rule 1",
|
||||
account=mail_account1,
|
||||
action=MailRule.MailAction.MOVE,
|
||||
)
|
||||
MailRule.objects.create(
|
||||
name="mail rule 2",
|
||||
account=mail_account2,
|
||||
action=MailRule.MailAction.MOVE,
|
||||
)
|
||||
custom_field1 = CustomField.objects.create(
|
||||
name="bank custom field 1",
|
||||
data_type=CustomField.FieldDataType.STRING,
|
||||
)
|
||||
CustomField.objects.create(
|
||||
name="custom field 2",
|
||||
data_type=CustomField.FieldDataType.INT,
|
||||
)
|
||||
workflow1 = Workflow.objects.create(name="bank workflow 1")
|
||||
Workflow.objects.create(name="workflow 2")
|
||||
|
||||
response = self.client.get("/api/search/?query=bank")
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
results = response.data
|
||||
self.assertEqual(len(results["documents"]), 2)
|
||||
self.assertNotEqual(results["documents"][0]["id"], d3.id)
|
||||
self.assertNotEqual(results["documents"][1]["id"], d3.id)
|
||||
self.assertEqual(results["correspondents"][0]["id"], correspondent1.id)
|
||||
self.assertEqual(results["document_types"][0]["id"], document_type1.id)
|
||||
self.assertEqual(results["storage_paths"][0]["id"], storage_path1.id)
|
||||
self.assertEqual(results["tags"][0]["id"], tag1.id)
|
||||
self.assertEqual(results["users"][0]["id"], user1.id)
|
||||
self.assertEqual(results["groups"][0]["id"], group1.id)
|
||||
self.assertEqual(results["mail_accounts"][0]["id"], mail_account1.id)
|
||||
self.assertEqual(results["mail_rules"][0]["id"], mail_rule1.id)
|
||||
self.assertEqual(results["custom_fields"][0]["id"], custom_field1.id)
|
||||
self.assertEqual(results["workflows"][0]["id"], workflow1.id)
|
||||
|
@ -17,6 +17,7 @@ from urllib.parse import urlparse
|
||||
import pathvalidate
|
||||
from django.apps import apps
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import Group
|
||||
from django.contrib.auth.models import User
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import connections
|
||||
@ -152,7 +153,13 @@ from paperless import version
|
||||
from paperless.celery import app as celery_app
|
||||
from paperless.config import GeneralConfig
|
||||
from paperless.db import GnuPG
|
||||
from paperless.serialisers import GroupSerializer
|
||||
from paperless.serialisers import UserSerializer
|
||||
from paperless.views import StandardPagination
|
||||
from paperless_mail.models import MailAccount
|
||||
from paperless_mail.models import MailRule
|
||||
from paperless_mail.serialisers import MailAccountSerializer
|
||||
from paperless_mail.serialisers import MailRuleSerializer
|
||||
|
||||
if settings.AUDIT_LOG_ENABLED:
|
||||
from auditlog.models import LogEntry
|
||||
@ -1119,6 +1126,98 @@ class SearchAutoCompleteView(APIView):
|
||||
)
|
||||
|
||||
|
||||
class GlobalSearchView(PassUserMixin):
|
||||
permission_classes = (IsAuthenticated,)
|
||||
serializer_class = SearchResultSerializer
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
query = request.query_params.get("query", None)
|
||||
if query is None:
|
||||
return HttpResponseBadRequest("Query required")
|
||||
elif len(query) < 3:
|
||||
return HttpResponseBadRequest("Query must be at least 3 characters")
|
||||
|
||||
docs = []
|
||||
from documents import index
|
||||
|
||||
with index.open_index_searcher() as s:
|
||||
q, _ = index.DelayedFullTextQuery(
|
||||
s,
|
||||
request.query_params,
|
||||
10,
|
||||
request.user,
|
||||
)._get_query()
|
||||
results = s.search(q, limit=10)
|
||||
docs = Document.objects.filter(id__in=[r["id"] for r in results])
|
||||
|
||||
tags = Tag.objects.filter(name__contains=query)
|
||||
correspondents = Correspondent.objects.filter(name__contains=query)
|
||||
document_types = DocumentType.objects.filter(name__contains=query)
|
||||
storage_paths = StoragePath.objects.filter(name__contains=query)
|
||||
users = User.objects.filter(username__contains=query)
|
||||
groups = Group.objects.filter(name__contains=query)
|
||||
mail_rules = MailRule.objects.filter(name__contains=query)
|
||||
mail_accounts = MailAccount.objects.filter(name__contains=query)
|
||||
workflows = Workflow.objects.filter(name__contains=query)
|
||||
custom_fields = CustomField.objects.filter(name__contains=query)
|
||||
|
||||
context = {
|
||||
"request": request,
|
||||
}
|
||||
|
||||
docs_serializer = DocumentSerializer(docs, many=True, context=context)
|
||||
tags_serializer = TagSerializer(tags, many=True, context=context)
|
||||
correspondents_serializer = CorrespondentSerializer(
|
||||
correspondents,
|
||||
many=True,
|
||||
context=context,
|
||||
)
|
||||
document_types_serializer = DocumentTypeSerializer(
|
||||
document_types,
|
||||
many=True,
|
||||
context=context,
|
||||
)
|
||||
storage_paths_serializer = StoragePathSerializer(
|
||||
storage_paths,
|
||||
many=True,
|
||||
context=context,
|
||||
)
|
||||
users_serializer = UserSerializer(users, many=True, context=context)
|
||||
groups_serializer = GroupSerializer(groups, many=True, context=context)
|
||||
mail_rules_serializer = MailRuleSerializer(
|
||||
mail_rules,
|
||||
many=True,
|
||||
context=context,
|
||||
)
|
||||
mail_accounts_serializer = MailAccountSerializer(
|
||||
mail_accounts,
|
||||
many=True,
|
||||
context=context,
|
||||
)
|
||||
workflows_serializer = WorkflowSerializer(workflows, many=True, context=context)
|
||||
custom_fields_serializer = CustomFieldSerializer(
|
||||
custom_fields,
|
||||
many=True,
|
||||
context=context,
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"documents": docs_serializer.data,
|
||||
"tags": tags_serializer.data,
|
||||
"correspondents": correspondents_serializer.data,
|
||||
"document_types": document_types_serializer.data,
|
||||
"storage_paths": storage_paths_serializer.data,
|
||||
"users": users_serializer.data,
|
||||
"groups": groups_serializer.data,
|
||||
"mail_rules": mail_rules_serializer.data,
|
||||
"mail_accounts": mail_accounts_serializer.data,
|
||||
"workflows": workflows_serializer.data,
|
||||
"custom_fields": custom_fields_serializer.data,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class StatisticsView(APIView):
|
||||
permission_classes = (IsAuthenticated,)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from documents.views import BulkEditView
|
||||
from documents.views import CorrespondentViewSet
|
||||
from documents.views import CustomFieldViewSet
|
||||
from documents.views import DocumentTypeViewSet
|
||||
from documents.views import GlobalSearchView
|
||||
from documents.views import IndexView
|
||||
from documents.views import LogViewSet
|
||||
from documents.views import PostDocumentView
|
||||
@ -91,6 +92,11 @@ urlpatterns = [
|
||||
SearchAutoCompleteView.as_view(),
|
||||
name="autocomplete",
|
||||
),
|
||||
re_path(
|
||||
"^search/",
|
||||
GlobalSearchView.as_view(),
|
||||
name="global_search",
|
||||
),
|
||||
re_path("^statistics/", StatisticsView.as_view(), name="statistics"),
|
||||
re_path(
|
||||
"^documents/post_document/",
|
||||
|
Loading…
x
Reference in New Issue
Block a user