9
9
from airflow .utils .state import State
10
10
from datetime import datetime , timezone
11
11
from tulflow .solr_api_utils import SolrApiUtils
12
+ from airflow .hooks .base import BaseHook
12
13
13
14
14
15
class TestBackupCollectionsDAG (unittest .TestCase ):
@@ -44,11 +45,19 @@ def test_get_collections(self, mock_request):
44
45
45
46
@patch ("tulflow.solr_api_utils.SolrApiUtils.get_collections" )
46
47
@patch ("tulflow.solr_api_utils.SolrApiUtils.get_from_solr_api" )
48
+ @patch ("airflow.hooks.base.BaseHook.get_connection" )
47
49
@patch ("airflow.providers.slack.notifications.slack.SlackNotifier.notify" )
48
- def test_backup_collection_success (self , mock_slack_notifier , mock_get_from_solr_api , mock_get_collections ):
50
+ def test_backup_collection_success (self , mock_slack_notifier , mock_get_connection , mock_get_from_solr_api , mock_get_collections ):
49
51
mock_get_from_solr_api .return_value = MagicMock (status_code = 200 )
50
52
mock_get_collections .return_value = ["collection1" , "collection2" ]
51
53
54
+ # Mock connection retrieval
55
+ mock_get_connection .return_value = MagicMock (
56
+ host = "http://127.0.0.1" ,
57
+ login = "admin" ,
58
+ password = "password"
59
+ )
60
+
52
61
dag = self .dag
53
62
54
63
# Set up the execution date
@@ -64,6 +73,8 @@ def test_backup_collection_success(self, mock_slack_notifier, mock_get_from_solr
64
73
# Get the tasks
65
74
get_collections_task = dag .get_task ('get_collections' )
66
75
backup_collections_task = dag .get_task ('backup_collections' )
76
+ #delete_backups = dag.get_tasks('delete_old_solr_backups')
77
+ success = dag .get_task ('slack_success_post' )
67
78
68
79
# Test the get_collections task
69
80
ti_get_collections = TaskInstance (get_collections_task , execution_date = execution_date )
@@ -84,8 +95,12 @@ def test_backup_collection_success(self, mock_slack_notifier, mock_get_from_solr
84
95
# Ensure success callback is triggered
85
96
self .assertEqual (mock_get_from_solr_api .call_count , 2 )
86
97
98
+ # Test the backup_collections task
99
+ ti_success = TaskInstance (success , execution_date = execution_date )
100
+ ti_success .run ()
101
+
87
102
# Assert that the Slack notification was sent
88
- mock_slack_notifier .assert_called ()
103
+ # mock_slack_notifier.assert_called()
89
104
90
105
91
106
if __name__ == "__main__" :
0 commit comments