From 9c9b48695d793d5149b97870a0751789fc667f89 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:37:18 -0800 Subject: [PATCH] Different way of mocking classifier --- src/documents/tests/test_api_status.py | 33 +++++++++++++------------- src/documents/views.py | 4 +++- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/documents/tests/test_api_status.py b/src/documents/tests/test_api_status.py index a64c8c1b4..fbb3af8b9 100644 --- a/src/documents/tests/test_api_status.py +++ b/src/documents/tests/test_api_status.py @@ -1,12 +1,13 @@ import os from unittest import mock -from django.conf import settings from django.contrib.auth.models import User from django.test import override_settings from rest_framework import status from rest_framework.test import APITestCase +from documents.classifier import DocumentClassifier +from documents.classifier import load_classifier from paperless import version @@ -92,23 +93,21 @@ class TestSystemStatus(APITestCase): self.assertEqual(response.data["tasks"]["index_status"], "ERROR") self.assertIsNotNone(response.data["tasks"]["index_error"]) - @override_settings(MODEL_FILE="/tmp/does_not_exist") + @override_settings(DATA_DIR="/tmp/does_not_exist/data/") def test_system_status_classifier_ok(self): - with open(settings.MODEL_FILE, "w") as f: - f.write("test") - f.close() - self.client.force_login(self.user) - response = self.client.get(self.ENDPOINT) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data["tasks"]["classifier_status"], "OK") - self.assertIsNone(response.data["tasks"]["classifier_error"]) - - @override_settings(MODEL_FILE="/tmp/does_not_exist") - @mock.patch("documents.classifier.load_classifier") - def test_system_status_classifier_error(self, mock_load_classifier): - mock_load_classifier.side_effect = Exception("Classifier error") + load_classifier() + test_classifier = DocumentClassifier() + test_classifier.save() self.client.force_login(self.user) response = self.client.get(self.ENDPOINT) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data["tasks"]["classifier_status"], "ERROR") - self.assertIsNotNone(response.data["tasks"]["classifier_error"]) + self.assertEqual(response.data["tasks"]["classifier_status"], "OK") + self.assertIsNone(response.data["tasks"]["classifier_error"]) + + def test_system_status_classifier_error(self): + with override_settings(MODEL_FILE="does_not_exist"): + self.client.force_login(self.user) + response = self.client.get(self.ENDPOINT) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["tasks"]["classifier_status"], "ERROR") + self.assertIsNotNone(response.data["tasks"]["classifier_error"]) diff --git a/src/documents/views.py b/src/documents/views.py index 4003c33df..5abb349ac 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1620,7 +1620,9 @@ class SystemStatusView(GenericAPIView, PassUserMixin): classifier_error = None try: - load_classifier() + classifier = load_classifier() + if classifier is None: + raise Exception("Classifier not loaded") classifier_status = "OK" task_result_model = apps.get_model("django_celery_results", "taskresult") result = (