Different way of mocking classifier

This commit is contained in:
shamoon 2024-02-15 17:37:18 -08:00
parent fc2c3dfc1a
commit 9c9b48695d
2 changed files with 19 additions and 18 deletions

View File

@ -1,12 +1,13 @@
import os import os
from unittest import mock from unittest import mock
from django.conf import settings
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.test import override_settings from django.test import override_settings
from rest_framework import status from rest_framework import status
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from documents.classifier import DocumentClassifier
from documents.classifier import load_classifier
from paperless import version from paperless import version
@ -92,23 +93,21 @@ class TestSystemStatus(APITestCase):
self.assertEqual(response.data["tasks"]["index_status"], "ERROR") self.assertEqual(response.data["tasks"]["index_status"], "ERROR")
self.assertIsNotNone(response.data["tasks"]["index_error"]) self.assertIsNotNone(response.data["tasks"]["index_error"])
@override_settings(MODEL_FILE="/tmp/does_not_exist") @override_settings(DATA_DIR="/tmp/does_not_exist/data/")
def test_system_status_classifier_ok(self): def test_system_status_classifier_ok(self):
with open(settings.MODEL_FILE, "w") as f: load_classifier()
f.write("test") test_classifier = DocumentClassifier()
f.close() test_classifier.save()
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["classifier_status"], "OK")
self.assertIsNone(response.data["tasks"]["classifier_error"])
@override_settings(MODEL_FILE="/tmp/does_not_exist")
@mock.patch("documents.classifier.load_classifier")
def test_system_status_classifier_error(self, mock_load_classifier):
mock_load_classifier.side_effect = Exception("Classifier error")
self.client.force_login(self.user) self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT) response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["classifier_status"], "ERROR") self.assertEqual(response.data["tasks"]["classifier_status"], "OK")
self.assertIsNotNone(response.data["tasks"]["classifier_error"]) self.assertIsNone(response.data["tasks"]["classifier_error"])
def test_system_status_classifier_error(self):
with override_settings(MODEL_FILE="does_not_exist"):
self.client.force_login(self.user)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["tasks"]["classifier_status"], "ERROR")
self.assertIsNotNone(response.data["tasks"]["classifier_error"])

View File

@ -1620,7 +1620,9 @@ class SystemStatusView(GenericAPIView, PassUserMixin):
classifier_error = None classifier_error = None
try: try:
load_classifier() classifier = load_classifier()
if classifier is None:
raise Exception("Classifier not loaded")
classifier_status = "OK" classifier_status = "OK"
task_result_model = apps.get_model("django_celery_results", "taskresult") task_result_model = apps.get_model("django_celery_results", "taskresult")
result = ( result = (