-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathUnitTest.py
More file actions
155 lines (128 loc) · 7.03 KB
/
UnitTest.py
File metadata and controls
155 lines (128 loc) · 7.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
import networkx as nx
import cv2
import tensorflow as tf
from datetime import datetime, timezone
# --- Unit Tests ---
class TestKalmanTracker(unittest.TestCase):
    """Unit tests for KalmanTracker: construction, predict, and update paths."""

    def setUp(self):
        # Silence logging: swap logging.getLogger for a mock during each test.
        logger_patcher = patch('logging.getLogger')
        self.mock_logger = logger_patcher.start()
        self.addCleanup(logger_patcher.stop)
        self.mock_logger.return_value = MagicMock()

        # State vector layout: [x, y, vx, vy].
        self.initial_state = np.array([1.0, 2.0, 0.5, 0.2], dtype=np.float32)
        self.tracker = KalmanTracker(self.initial_state)
        # Pin a known pre-state so assertions below are deterministic.
        self.tracker.filter.statePre = self.initial_state.reshape(-1, 1)

    def test_init(self):
        self.assertIsNotNone(self.tracker.filter)
        expected = self.initial_state.reshape(-1, 1)
        self.assertTrue(np.array_equal(self.tracker.filter.statePre, expected))

    def test_init_invalid_state(self):
        # A 2-element vector is the wrong dimensionality for [x, y, vx, vy].
        with self.assertRaises(ValueError):
            KalmanTracker(np.array([1.0, 2.0], dtype=np.float32))
        # None must be rejected as well.
        with self.assertRaises(ValueError):
            KalmanTracker(None)

    def test_predict(self):
        predicted = self.tracker.predict()
        self.assertIsNotNone(predicted)
        # The full state comes back as a (4, 1) column vector.
        self.assertEqual(predicted.shape, (4, 1))
        # TODO: add more specific checks on the predicted values if needed

    def test_predict_opencv_error(self):
        # Simulate an OpenCV failure inside the filter's predict().
        self.tracker.filter.predict = MagicMock(side_effect=cv2.error)
        # The tracker is expected to swallow the error and signal it via None.
        self.assertIsNone(self.tracker.predict())

    def test_update(self):
        measurement = np.array([1.2, 2.1], dtype=np.float32)  # new (x, y) observation
        updated = self.tracker.update(measurement)
        self.assertIsNotNone(updated)
        self.assertEqual(updated.shape, (4, 1))
        # TODO: add more specific checks on updated values if needed

    def test_update_invalid_measurement(self):
        # A one-element measurement has the wrong dimension for (x, y).
        with self.assertRaises(ValueError):
            self.tracker.update(np.array([1.0], dtype=np.float32))
        with self.assertRaises(ValueError):
            self.tracker.update(None)

    def test_update_opencv_error(self):
        # Simulate an OpenCV failure inside correct(); update should not raise.
        self.tracker.filter.correct = MagicMock(side_effect=cv2.error)
        measurement = np.array([1.2, 2.1], dtype=np.float32)
        result = self.tracker.update(measurement)
        # Expected to fall back to the previous state rather than raise.
        self.assertIsNotNone(result)
        # Add assertion to check it returns previous state if needed
class TestObjectBehaviorLearner(unittest.TestCase):
    """Unit tests for ObjectBehaviorLearner: setup, train/predict, and the local graph cache."""

    def setUp(self):
        # Keep test output clean: mock out logging.getLogger.
        logger_patcher = patch('logging.getLogger')
        self.mock_logger = logger_patcher.start()
        self.addCleanup(logger_patcher.stop)
        self.mock_logger.return_value = MagicMock()

        # Replace TransformerBlock with an identity stub so the model wiring
        # can be exercised without real attention layers.
        def identity_block(x, training):
            return x

        block_patcher = patch(
            'robots.TransformerBlock',
            return_value=MagicMock(side_effect=identity_block),
        )  # Assuming the implementation lives in robots.py
        self.mock_transformer = block_patcher.start()
        self.addCleanup(block_patcher.stop)

        self.learner = ObjectBehaviorLearner(
            state_dimension=4, d_model=64, num_heads=2, dff=128, num_layers=1
        )

    def test_init(self):
        # A fresh learner exposes a graph cache and a Keras model.
        self.assertIsInstance(self.learner.local_object_graph_cache, nx.Graph)
        self.assertIsInstance(self.learner.model, tf.keras.Model)

    def test_init_invalid_config(self):
        # d_model must divide evenly by num_heads; 65 % 2 != 0 should fail.
        with self.assertRaises(Exception):
            ObjectBehaviorLearner(
                state_dimension=4, d_model=65, num_heads=2, dff=128, num_layers=1
            )
        # The logger should capture the warning, but we don't assert that in this unit test.

    def test_train_predict(self):
        # Dummy data: 100 samples, 5 past steps / 10 horizon steps, 4 state dims.
        past_states = np.random.rand(100, 5, 4)
        target_states = np.random.rand(100, 10, 4)

        # Training: with TransformerBlock globally patched we cannot easily spy
        # on model.fit itself, so the check here is that train() completes.
        self.learner.train(past_states, target_states)

        # Prediction on a subset should preserve (samples, horizon, dims).
        predictions = self.learner.predict(past_states[:10])
        self.assertIsNotNone(predictions)
        self.assertEqual(predictions.shape, (10, 10, 4))

    def test_train_predict_invalid_data(self):
        # Incomplete - should cover None / wrongly shaped inputs for both
        # train() and predict(), asserting errors are handled correctly.
        pass

    def test_update_local_cache(self):
        object_id = "obj_123"
        state = np.array([[1.0, 2.0, 0.5, 0.1]], dtype=np.float32)
        timestamp = datetime.now(timezone.utc).isoformat()

        # Inserting a new object creates a node carrying state + timestamp.
        self.learner.update_local_cache(object_id, state, timestamp)
        self.assertIn(object_id, self.learner.local_object_graph_cache.nodes)
        node_data = self.learner.local_object_graph_cache.nodes[object_id]
        self.assertEqual(node_data['state'], state.flatten().tolist())
        self.assertEqual(node_data['timestamp'], timestamp)

        # Re-inserting the same id overwrites the stored state and timestamp.
        new_state = np.array([[1.1, 2.1, 0.6, 0.2]], dtype=np.float32)
        new_timestamp = datetime.now(timezone.utc).isoformat()
        self.learner.update_local_cache(object_id, new_state, new_timestamp)
        node_data = self.learner.local_object_graph_cache.nodes[object_id]
        self.assertEqual(node_data['state'], new_state.flatten().tolist())
        self.assertEqual(node_data['timestamp'], new_timestamp)
class TestPredictionCombiner(unittest.TestCase):
    def setUp(self):
        """Build a small PredictionCombiner with example dimensions for each test."""
        # Mock the logger so tests produce no real log output.
        patcher = patch('logging.getLogger')
        self.mock_logger = patcher.start()
        self.addCleanup(patcher.stop)
        self.mock_logger.return_value = MagicMock()
        self.transformer_output_dim = 4  # Example
        self.context_dim = 10  # Example
        self.prediction_horizon = 10  # Example
        # output_dim=2 presumably means (x, y) predictions — TODO confirm against PredictionCombiner.
        self.combiner = PredictionCombiner(self.transformer_output_dim, self.context_dim, output_dim=2,
                                           prediction_horizon=self.prediction_horizon, use_confidence_weighting=True)
    def test_init(self):
        """PredictionCombiner should be a Keras model with its key layers constructed."""
        self.assertIsInstance(self.combiner, tf.keras.Model)
        # Check that layers were created (not exhaustive)
        self.assertIsNotNone(self.combiner.dense1)
        self.assertIsNotNone(self.combiner.layer_norm)
def test_call(self):
# Dummy Inputs
num_samples = 5
transformer_predictions = tf.random.normal((num